From 4ce2782a61e23ed0326faac2ee97a9bd36ec8963 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 23 Mar 2015 23:41:06 -0700
Subject: [PATCH 001/171] [SPARK-6428] Added explicit types for all public methods in core.

Author: Reynold Xin

Closes #5125 from rxin/core-explicit-type and squashes the following commits:

f471415 [Reynold Xin] Revert style checker changes.
81b66e4 [Reynold Xin] Code review feedback.
a7533e3 [Reynold Xin] Mima excludes.
1d795f5 [Reynold Xin] [SPARK-6428] Added explicit types for all public methods in core.
---
 .../scala/org/apache/spark/Accumulators.scala | 23 +++--
 .../scala/org/apache/spark/Dependency.scala | 6 +-
 .../scala/org/apache/spark/FutureAction.scala | 4 +-
 .../org/apache/spark/HeartbeatReceiver.scala | 2 +-
 .../org/apache/spark/MapOutputTracker.scala | 2 +-
 .../scala/org/apache/spark/Partitioner.scala | 4 +-
 .../apache/spark/SerializableWritable.scala | 6 +-
 .../scala/org/apache/spark/SparkConf.scala | 2 +-
 .../scala/org/apache/spark/SparkContext.scala | 48 +++++-----
 .../scala/org/apache/spark/TaskState.scala | 4 +-
 .../scala/org/apache/spark/TestUtils.scala | 2 +-
 .../org/apache/spark/api/java/JavaRDD.scala | 2 +-
 .../spark/api/java/JavaSparkContext.scala | 2 +-
 .../org/apache/spark/api/java/JavaUtils.scala | 35 ++++----
 .../apache/spark/api/python/PythonRDD.scala | 39 ++++----
 .../apache/spark/api/python/SerDeUtil.scala | 2 +-
 .../WriteInputFormatTestDataGenerator.scala | 27 +++---
 .../apache/spark/broadcast/Broadcast.scala | 2 +-
 .../spark/broadcast/BroadcastManager.scala | 2 +-
 .../spark/broadcast/HttpBroadcast.scala | 4 +-
 .../broadcast/HttpBroadcastFactory.scala | 2 +-
 .../spark/broadcast/TorrentBroadcast.scala | 2 +-
 .../broadcast/TorrentBroadcastFactory.scala | 2 +-
 .../org/apache/spark/deploy/Client.scala | 4 +-
 .../apache/spark/deploy/ClientArguments.scala | 2 +-
 .../apache/spark/deploy/DeployMessage.scala | 2 +-
 .../spark/deploy/FaultToleranceTest.scala | 2 +-
 .../apache/spark/deploy/JsonProtocol.scala | 15 ++--
 .../apache/spark/deploy/SparkHadoopUtil.scala | 2 +-
 .../org/apache/spark/deploy/SparkSubmit.scala | 2 +-
 .../spark/deploy/SparkSubmitArguments.scala | 2 +-
 .../spark/deploy/client/AppClient.scala | 2 +-
 .../deploy/history/FsHistoryProvider.scala | 4 +-
 .../spark/deploy/history/HistoryServer.scala | 10 +--
 .../master/FileSystemPersistenceEngine.scala | 2 +-
 .../apache/spark/deploy/master/Master.scala | 2 +-
 .../deploy/master/RecoveryModeFactory.scala | 15 +++-
 .../spark/deploy/master/WorkerInfo.scala | 2 +-
 .../master/ZooKeeperPersistenceEngine.scala | 2 +-
 .../spark/deploy/master/ui/MasterPage.scala | 2 +-
 .../spark/deploy/worker/DriverRunner.scala | 22 +++--
 .../apache/spark/deploy/worker/Worker.scala | 2 +-
 .../spark/deploy/worker/WorkerWatcher.scala | 2 +-
 .../CoarseGrainedExecutorBackend.scala | 2 +-
 .../apache/spark/executor/ExecutorActor.scala | 2 +-
 .../apache/spark/executor/TaskMetrics.scala | 90 +++++++++----------
 .../spark/input/PortableDataStream.scala | 17 ++--
 .../spark/mapred/SparkHadoopMapRedUtil.scala | 2 +-
 .../mapreduce/SparkHadoopMapReduceUtil.scala | 2 +-
 .../apache/spark/metrics/MetricsSystem.scala | 3 +-
 .../spark/metrics/sink/MetricsServlet.scala | 18 ++--
 .../org/apache/spark/metrics/sink/Sink.scala | 4 +-
 .../spark/network/nio/BlockMessageArray.scala | 6 +-
 .../spark/network/nio/BufferMessage.scala | 10 +--
 .../apache/spark/network/nio/Connection.scala | 8 +-
 .../spark/network/nio/ConnectionId.scala | 4 +-
 .../spark/network/nio/ConnectionManager.scala | 2 +-
 .../network/nio/ConnectionManagerId.scala | 2 +-
 .../apache/spark/network/nio/Message.scala | 6 +-
 .../spark/network/nio/MessageChunk.scala | 6 +-
 .../network/nio/MessageChunkHeader.scala | 4 +-
 .../apache/spark/partial/PartialResult.scala | 2 +-
 .../org/apache/spark/rdd/CartesianRDD.scala | 2 +-
 .../org/apache/spark/rdd/CoalescedRDD.scala | 10 +--
 .../org/apache/spark/rdd/HadoopRDD.scala | 13 ++-
 .../scala/org/apache/spark/rdd/JdbcRDD.scala | 10 ++-
 .../apache/spark/rdd/MapPartitionsRDD.scala | 2 +-
 .../org/apache/spark/rdd/NewHadoopRDD.scala | 2 +-
 .../spark/rdd/ParallelCollectionRDD.scala | 2 +-
 .../spark/rdd/PartitionPruningRDD.scala | 10 ++-
 .../scala/org/apache/spark/rdd/PipedRDD.scala | 8 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala | 12 +--
 .../org/apache/spark/rdd/ShuffledRDD.scala | 4 +-
 .../org/apache/spark/rdd/SubtractedRDD.scala | 2 +-
 .../scala/org/apache/spark/rdd/UnionRDD.scala | 2 +-
 .../spark/rdd/ZippedPartitionsRDD.scala | 2 +-
 .../spark/scheduler/AccumulableInfo.scala | 7 +-
 .../apache/spark/scheduler/DAGScheduler.scala | 2 +-
 .../scheduler/EventLoggingListener.scala | 61 ++++++++-----
 .../apache/spark/scheduler/JobLogger.scala | 2 +-
 .../apache/spark/scheduler/JobWaiter.scala | 2 +-
 .../scheduler/OutputCommitCoordinator.scala | 2 +-
 .../apache/spark/scheduler/ResultTask.scala | 2 +-
 .../spark/scheduler/ShuffleMapTask.scala | 4 +-
 .../spark/scheduler/SparkListener.scala | 4 +-
 .../org/apache/spark/scheduler/Stage.scala | 2 +-
 .../apache/spark/scheduler/TaskLocation.scala | 19 ++--
 .../spark/scheduler/TaskSchedulerImpl.scala | 16 ++--
 .../spark/scheduler/TaskSetManager.scala | 13 +--
 .../CoarseGrainedSchedulerBackend.scala | 2 +-
 .../cluster/YarnSchedulerBackend.scala | 2 +-
 .../scheduler/cluster/mesos/MemoryUtils.scala | 2 +-
 .../cluster/mesos/MesosSchedulerBackend.scala | 2 +-
 .../spark/scheduler/local/LocalBackend.scala | 4 +-
 .../spark/serializer/JavaSerializer.scala | 5 +-
 .../spark/serializer/KryoSerializer.scala | 2 +-
 .../shuffle/FileShuffleBlockManager.scala | 4 +-
 .../shuffle/IndexShuffleBlockManager.scala | 4 +-
 .../org/apache/spark/storage/BlockId.scala | 34 +++----
 .../apache/spark/storage/BlockManagerId.scala | 8 +-
 .../spark/storage/BlockManagerMaster.scala | 2 +-
 .../storage/BlockManagerMasterActor.scala | 4 +-
 .../storage/BlockManagerSlaveActor.scala | 2 +-
 .../spark/storage/BlockObjectWriter.scala | 17 ++--
 .../apache/spark/storage/FileSegment.scala | 4 +-
 .../org/apache/spark/storage/RDDInfo.scala | 4 +-
 .../apache/spark/storage/StorageLevel.scala | 16 ++--
 .../spark/storage/StorageStatusListener.scala | 7 +-
 .../spark/storage/TachyonFileSegment.scala | 4 +-
 .../scala/org/apache/spark/ui/SparkUI.scala | 2 +-
 .../scala/org/apache/spark/ui/UIUtils.scala | 6 +-
 .../apache/spark/ui/UIWorkloadGenerator.scala | 4 +-
 .../apache/spark/ui/exec/ExecutorsTab.scala | 10 +--
 .../spark/ui/jobs/JobProgressListener.scala | 14 +--
 .../org/apache/spark/ui/jobs/JobsTab.scala | 2 +-
 .../org/apache/spark/ui/jobs/StagePage.scala | 10 ++-
 .../org/apache/spark/ui/jobs/StagesTab.scala | 6 +-
 .../org/apache/spark/ui/jobs/UIData.scala | 10 +--
 .../apache/spark/ui/storage/StorageTab.scala | 12 +--
 .../spark/util/CompletionIterator.scala | 10 +--
 .../org/apache/spark/util/Distribution.scala | 2 +-
 .../org/apache/spark/util/ManualClock.scala | 30 +++----
 .../apache/spark/util/MetadataCleaner.scala | 7 +-
 .../org/apache/spark/util/MutablePair.scala | 2 +-
 .../apache/spark/util/ParentClassLoader.scala | 2 +-
 .../spark/util/SerializableBuffer.scala | 2 +-
 .../org/apache/spark/util/StatCounter.scala | 4 +-
 .../util/TimeStampedWeakValueHashMap.scala | 6 +-
 .../scala/org/apache/spark/util/Utils.scala | 30 +++---
 .../apache/spark/util/collection/BitSet.scala | 4 +-
 .../collection/ExternalAppendOnlyMap.scala | 4 +-
 .../util/collection/ExternalSorter.scala | 2 +-
 .../spark/util/collection/OpenHashMap.scala | 6 +-
 .../spark/util/collection/OpenHashSet.scala | 4 +-
 .../collection/PrimitiveKeyOpenHashMap.scala | 8 +-
 .../apache/spark/util/collection/Utils.scala | 2 +-
 .../spark/util/logging/FileAppender.scala | 4 +-
 .../spark/util/random/RandomSampler.scala | 44 +++++----
 .../util/random/StratifiedSamplingUtils.scala | 2 +-
 .../spark/util/random/XORShiftRandom.scala | 2 +-
 project/MimaExcludes.scala | 8 +-
 scalastyle-config.xml | 4 +-
 142 files changed, 597 insertions(+), 526 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index bcf832467f00b..330df1d59a9b1 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -18,8 +18,6 @@
 package org.apache.spark
 
 import java.io.{ObjectInputStream, Serializable}
-import java.util.concurrent.atomic.AtomicLong
-import java.lang.ThreadLocal
 
 import scala.collection.generic.Growable
 import scala.collection.mutable.Map
@@ -109,7 +107,7 @@ class Accumulable[R, T] (
    * The typical use of this method is to directly mutate the local value, eg., to add
    * an element to a Set.
    */
-  def localValue = value_
+  def localValue: R = value_
 
   /**
    * Set the accumulator's value; only allowed on master.
@@ -137,7 +135,7 @@ class Accumulable[R, T] (
     Accumulators.register(this, false)
   }
 
-  override def toString = if (value_ == null) "null" else value_.toString
+  override def toString: String = if (value_ == null) "null" else value_.toString
 }
 
 /**
@@ -257,22 +255,22 @@ object AccumulatorParam {
 
   implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
     def addInPlace(t1: Double, t2: Double): Double = t1 + t2
-    def zero(initialValue: Double) = 0.0
+    def zero(initialValue: Double): Double = 0.0
   }
 
   implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
     def addInPlace(t1: Int, t2: Int): Int = t1 + t2
-    def zero(initialValue: Int) = 0
+    def zero(initialValue: Int): Int = 0
   }
 
   implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
-    def addInPlace(t1: Long, t2: Long) = t1 + t2
-    def zero(initialValue: Long) = 0L
+    def addInPlace(t1: Long, t2: Long): Long = t1 + t2
+    def zero(initialValue: Long): Long = 0L
   }
 
   implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
-    def addInPlace(t1: Float, t2: Float) = t1 + t2
-    def zero(initialValue: Float) = 0f
+    def addInPlace(t1: Float, t2: Float): Float = t1 + t2
+    def zero(initialValue: Float): Float = 0f
   }
 
   // TODO: Add AccumulatorParams for other types, e.g.
lists and strings @@ -351,6 +349,7 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) - def stringifyValue(value: Any) = "%s".format(value) + def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) + + def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 9a7cd4523e5ab..fc8cdde9348ee 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -74,7 +74,7 @@ class ShuffleDependency[K, V, C]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { - override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] + override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] val shuffleId: Int = _rdd.context.newShuffleId() @@ -91,7 +91,7 @@ class ShuffleDependency[K, V, C]( */ @DeveloperApi class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = List(partitionId) + override def getParents(partitionId: Int): List[Int] = List(partitionId) } @@ -107,7 +107,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) } else { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e97a7375a267b..91f9ef8ce7185 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -168,7 +168,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } } - def jobIds = Seq(jobWaiter.jobId) + def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } @@ -276,7 +276,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds = jobs + def jobIds: Seq[Int] = jobs } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 69178da1a7773..715f292f03469 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -65,7 +65,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule super.preStart() } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val unknownExecutor = !scheduler.executorHeartbeatReceived( executorId, taskMetrics, blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6e4edc7c80d7a..c9426c5de23a2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -43,7 +43,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster extends Actor with ActorLogReceive with Logging { val maxAkkaFrameSize = 
AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e53a78ead2c0e..b8d244408bc5b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -76,7 +76,7 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { case null => 0 @@ -154,7 +154,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } - def numPartitions = rangeBounds.length + 1 + def numPartitions: Int = rangeBounds.length + 1 private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index 55cb25946c2ad..cb2cae185256a 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -28,8 +28,10 @@ import org.apache.spark.util.Utils @DeveloperApi class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { - def value = t - override def toString = t.toString + + def value: T = t + + override def toString: String = t.toString private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { out.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 2ca19f53d2f07..0c123c96b8d7b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -133,7 +133,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Set multiple parameters together */ - def setAll(settings: Traversable[(String, String)]) = { + def setAll(settings: Traversable[(String, String)]): SparkConf = { this.settings.putAll(settings.toMap.asJava) this } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 228ff715fe7cb..a70be16f77eeb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -986,7 +986,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli union(Seq(first) ++ rest) /** Get an RDD that has no partitions or elements. */ - def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) + def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) // Methods for creating shared variables @@ -994,7 +994,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `+=` method. Only the driver can access the accumulator's `value`. 
*/ - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = { val acc = new Accumulator(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) @@ -1006,7 +1006,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the * driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) + : Accumulator[T] = { val acc = new Accumulator(initialValue, param, Some(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1018,7 +1019,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = { + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1031,7 +1033,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = { + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param, Some(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1209,7 +1212,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) /** The version of Spark on which this application is running. */ - def version = SPARK_VERSION + def version: String = SPARK_VERSION /** * Return a map from the slave to the max memory available for caching and the remaining @@ -1659,7 +1662,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - def getCheckpointDir = checkpointDir + def getCheckpointDir: Option[String] = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: Int = { @@ -1900,28 +1903,28 @@ object SparkContext extends Logging { "backward compatibility.", "1.3.0") object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } @deprecated("Replaced by implicit objects in AccumulatorParam. 
This is kept here only for " + "backward compatibility.", "1.3.0") object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // The following deprecated functions have already been moved to `object RDD` to @@ -1931,18 +1934,18 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = RDD.rddToPairRDDFunctions(rdd) - } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = + RDD.rddToAsyncRDDActions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]) = { + rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { val kf = implicitly[K => Writable] val vf = implicitly[V => Writable] // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it @@ -1954,16 +1957,17 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V, (K, V)] = RDD.rddToOrderedRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = + RDD.doubleRDDToDoubleRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]): DoubleRDDFunctions = RDD.numericRDDToDoubleRDDFunctions(rdd) // The following deprecated functions have already been moved to `object WritableFactory` to @@ -2134,7 +2138,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_N_REGEX(threads) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. 
val threadCount = if (threads == "*") localCpuCount else threads.toInt if (threadCount <= 0) { @@ -2146,7 +2150,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*, M] means the number of cores on the computer with M failures // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index c415fe99b105e..fe19f07e32d1b 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -27,9 +27,9 @@ private[spark] object TaskState extends Enumeration { type TaskState = Value - def isFailed(state: TaskState) = (LOST == state) || (FAILED == state) + def isFailed(state: TaskState): Boolean = (LOST == state) || (FAILED == state) - def isFinished(state: TaskState) = FINISHED_STATES.contains(state) + def isFinished(state: TaskState): Boolean = FINISHED_STATES.contains(state) def toMesos(state: TaskState): MesosTaskState = state match { case LAUNCHING => MesosTaskState.TASK_STARTING diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 35b324ba6f573..398ca41e16151 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -107,7 +107,7 @@ private[spark] object TestUtils { private class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { - override def getCharContent(ignoreEncodingErrors: Boolean) = code + override def getCharContent(ignoreEncodingErrors: Boolean): String = code } /** Creates a compiled class with the given name. Class file will be placed in destDir. 
*/ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 3e9beb670f7ad..18ccd625fc8d1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -179,7 +179,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - override def toString = rdd.toString + override def toString: String = rdd.toString /** Assign a name to this RDD */ def setName(name: String): JavaRDD[T] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 6d6ed693be752..3be6783bba49d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -108,7 +108,7 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env - def statusTracker = new JavaSparkStatusTracker(sc) + def statusTracker: JavaSparkStatusTracker = new JavaSparkStatusTracker(sc) def isLocal: java.lang.Boolean = sc.isLocal diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 71b26737b8c02..8f9647eea9e25 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.api.java +import java.util.Map.Entry + import com.google.common.base.Optional import java.{util => ju} @@ -30,8 +32,8 @@ private[spark] object JavaUtils { } // Workaround for SPARK-3926 / SI-8911 - def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = - new SerializableMapWrapper(underlying) + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]): SerializableMapWrapper[A, B] + = new SerializableMapWrapper(underlying) // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper, // but implements java.io.Serializable. 
It can't just be subclassed to make it @@ -40,36 +42,33 @@ private[spark] object JavaUtils { class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) extends ju.AbstractMap[A, B] with java.io.Serializable { self => - override def size = underlying.size + override def size: Int = underlying.size override def get(key: AnyRef): B = try { - underlying get key.asInstanceOf[A] match { - case None => null.asInstanceOf[B] - case Some(v) => v - } + underlying.getOrElse(key.asInstanceOf[A], null.asInstanceOf[B]) } catch { case ex: ClassCastException => null.asInstanceOf[B] } override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] { - def size = self.size + override def size: Int = self.size - def iterator = new ju.Iterator[ju.Map.Entry[A, B]] { + override def iterator: ju.Iterator[ju.Map.Entry[A, B]] = new ju.Iterator[ju.Map.Entry[A, B]] { val ui = underlying.iterator var prev : Option[A] = None - def hasNext = ui.hasNext + def hasNext: Boolean = ui.hasNext - def next() = { - val (k, v) = ui.next + def next(): Entry[A, B] = { + val (k, v) = ui.next() prev = Some(k) new ju.Map.Entry[A, B] { import scala.util.hashing.byteswap32 - def getKey = k - def getValue = v - def setValue(v1 : B) = self.put(k, v1) - override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) - override def equals(other: Any) = other match { + override def getKey: A = k + override def getValue: B = v + override def setValue(v1 : B): B = self.put(k, v1) + override def hashCode: Int = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) + override def equals(other: Any): Boolean = other match { case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue case _ => false } @@ -81,7 +80,7 @@ private[spark] object JavaUtils { case Some(k) => underlying match { case mm: mutable.Map[A, _] => - mm remove k + mm.remove(k) prev = None case _ => throw new UnsupportedOperationException("remove") diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4c71b69069eb3..19f4c95fcad74 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -54,9 +54,11 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = firstParent.partitions + override def getPartitions: Array[Partition] = firstParent.partitions - override val partitioner = if (preservePartitoning) firstParent.partitioner else None + override val partitioner: Option[Partitioner] = { + if (preservePartitoning) firstParent.partitioner else None + } override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -92,7 +94,7 @@ private[spark] class PythonRDD( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { - def next(): Array[Byte] = { + override def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { _nextObj = read() @@ -175,7 +177,7 @@ private[spark] class PythonRDD( var _nextObj = read() - def hasNext = _nextObj != null + override def hasNext: Boolean = _nextObj != null } new InterruptibleIterator(context, stdoutIterator) } @@ -303,11 +305,10 @@ private class 
PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends - RDD[(Long, Array[Byte])](prev) { - override def getPartitions = prev.partitions - override val partitioner = prev.partitioner - override def compute(split: Partition, context: TaskContext) = +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { + override def getPartitions: Array[Partition] = prev.partitions + override val partitioner: Option[Partitioner] = prev.partitioner + override def compute(split: Partition, context: TaskContext): Iterator[(Long, Array[Byte])] = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (Utils.deserializeLongValue(a), b) case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) @@ -435,7 +436,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, minSplits: Int, - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] @@ -462,7 +463,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -488,7 +489,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -505,7 +506,7 @@ private[spark] object PythonRDD extends Logging { inputFormatClass: String, keyClass: String, valueClass: String, - conf: Configuration) = { + conf: Configuration): RDD[(K, V)] = { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] @@ -531,7 +532,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -557,7 +558,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -686,7 +687,7 @@ private[spark] object PythonRDD extends Logging { pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { saveAsHadoopFile( pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", null, null, null, null, new java.util.HashMap(), compressionCodecClass) 
@@ -711,7 +712,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -741,7 +742,7 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String]): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -766,7 +767,7 @@ private[spark] object PythonRDD extends Logging { confAsMap: java.util.HashMap[String, String], keyConverterClass: String, valueConverterClass: String, - useNewAPI: Boolean) = { + useNewAPI: Boolean): Unit = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), keyConverterClass, valueConverterClass, new JavaToWritableConverter) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index fb52a960e0765..257491e90dd66 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -84,7 +84,7 @@ private[spark] object SerDeUtil extends Logging { private var initialized = false // This should be called before trying to unpickle array.array from Python // In cluster mode, this should be put in closure - def initialize() = { + def initialize(): Unit = { synchronized{ if (!initialized) { Unpickler.registerConstructor("array", "array", new ArrayConstructor()) diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index cf289fb3ae39f..8f30ff9202c83 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,38 +18,37 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} +import java.{util => ju} import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat + +import org.apache.spark.SparkException import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.{SparkContext, SparkException} /** * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python - * @param str - * @param int - * @param double */ case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable { def this() = this("", 0, 0.0) - def getStr = str + def getStr: String = str def setStr(str: String) { this.str = str } - def getInt = int + def getInt: Int = int def setInt(int: Int) { this.int = int } - def getDouble = double + def getDouble: Double = double def setDouble(double: Double) { this.double = double } - def write(out: DataOutput) = { + def write(out: DataOutput): Unit = { out.writeUTF(str) out.writeInt(int) out.writeDouble(double) } 
- def readFields(in: DataInput) = { + def readFields(in: DataInput): Unit = { str = in.readUTF() int = in.readInt() double = in.readDouble() @@ -57,28 +56,28 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } private[python] class TestInputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Char = { obj.asInstanceOf[IntWritable].get().toChar } } private[python] class TestInputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) } } private[python] class TestOutputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Text = { new Text(obj.asInstanceOf[Int].toString) } } private[python] class TestOutputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): DoubleWritable = { new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) } } @@ -86,7 +85,7 @@ private[python] class TestOutputValueConverter extends Converter[Any, Any] { private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { - override def convert(obj: Any) = obj match { + override def convert(obj: Any): DoubleArrayWritable = obj match { case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => val daw = new DoubleArrayWritable daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index a5ea478f231d7..12d79f6ed311b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -146,5 +146,5 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo } } - override def toString = "Broadcast(" + id + ")" + override def toString: String = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 8f8a0b11f9f2e..685313ac009ba 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -58,7 +58,7 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = { + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 1444c0dd3d2d6..74ccfa6d3c9a3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -160,7 +160,7 @@ private[broadcast] object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def getFile(id: Long) = new 
File(broadcastDir, BroadcastBlockId(id).name) + def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) private def write(id: Long, value: Any) { val file = getFile(id) @@ -222,7 +222,7 @@ private[broadcast] object HttpBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver * and delete the associated broadcast file. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) if (removeFromDriver) { val file = getFile(id) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c7ef02d572a19..cf3ae36f27949 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = new HttpBroadcast[T](value_, isLocal, id) override def stop() { HttpBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 94142d33369c7..23b02e60338fb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -222,7 +222,7 @@ private object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. * If removeFromDriver is true, also remove these persisted blocks on the driver. 
*/ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index fb024c12094f2..96d8dd79908c8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -30,7 +30,7 @@ class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = { + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { new TorrentBroadcast[T](value_, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 237d26fc6bd0e..65238af2caa24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -38,7 +38,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) var masterActor: ActorSelection = _ val timeout = AkkaUtils.askTimeout(conf) - override def preStart() = { + override def preStart(): Unit = { masterActor = context.actorSelection( Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) @@ -118,7 +118,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case SubmitDriverResponse(success, driverId, message) => println(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 53bc62aff7395..5cbac787dceeb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -42,7 +42,7 @@ private[deploy] class ClientArguments(args: Array[String]) { var memory: Int = DEFAULT_MEMORY var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() - def driverOptions = _driverOptions.toSeq + def driverOptions: Seq[String] = _driverOptions.toSeq // kill parameters var driverId: String = "" diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 7f600d89604a2..0997507d016f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -162,7 +162,7 @@ private[deploy] object DeployMessages { Utils.checkHost(host, "Required hostname") assert (port > 0) - def uri = "spark://" + host + ":" + port + def uri: String = "spark://" + host + ":" + port def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 5668b53fc6f4f..a7c89276a045e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -426,7 +426,7 @@ private object SparkDocker { } private class DockerId(val id: String) { - override def toString = id + override def toString: String = id } private object Docker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 458a7c3a455de..dfc5b97e6a6c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} @@ -24,7 +25,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.ExecutorRunner private[deploy] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo) = { + def writeWorkerInfo(obj: WorkerInfo): JObject = { ("id" -> obj.id) ~ ("host" -> obj.host) ~ ("port" -> obj.port) ~ @@ -39,7 +40,7 @@ private[deploy] object JsonProtocol { ("lastheartbeat" -> obj.lastHeartbeat) } - def writeApplicationInfo(obj: ApplicationInfo) = { + def writeApplicationInfo(obj: ApplicationInfo): JObject = { ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ @@ -51,7 +52,7 @@ private[deploy] object JsonProtocol { ("duration" -> obj.duration) } - def writeApplicationDescription(obj: ApplicationDescription) = { + def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ ("memoryperslave" -> obj.memoryPerSlave) ~ @@ -59,14 +60,14 @@ private[deploy] object JsonProtocol { ("command" -> obj.command.toString) } - def writeExecutorRunner(obj: ExecutorRunner) = { + def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ ("appid" -> obj.appId) ~ ("appdesc" -> writeApplicationDescription(obj.appDesc)) } - def writeDriverInfo(obj: DriverInfo) = { + def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ @@ -74,7 +75,7 @@ private[deploy] object JsonProtocol { ("memory" -> obj.desc.mem) } - def writeMasterState(obj: MasterStateResponse) = { + def writeMasterState(obj: MasterStateResponse): JObject = { ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ ("cores" -> obj.workers.map(_.cores).sum) ~ @@ -87,7 +88,7 @@ private[deploy] object JsonProtocol { ("status" -> obj.status.toString) } - def writeWorkerState(obj: WorkerStateResponse) = { + def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ ("masterwebuiurl" -> obj.masterWebUiUrl) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e0a32fb65cd51..c2568eb4b60ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -193,7 +193,7 @@ class SparkHadoopUtil extends Logging { * that file. 
*/ def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { - def recurse(path: Path) = { + def recurse(path: Path): Array[FileStatus] = { val (directories, leaves) = fs.listStatus(path).partition(_.isDir) leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4f506be63fe59..660307d19eab4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -777,7 +777,7 @@ private[deploy] object SparkSubmitUtils { } /** A nice function to use in tests as well. Values are dummy strings. */ - def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 2250d5a28e4ef..6eb73c43470a5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -252,7 +252,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S master.startsWith("spark://") && deployMode == "cluster" } - override def toString = { + override def toString: String = { s"""Parsed arguments: | master $master | deployMode $deployMode diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 2d24083a77b73..3b729725257ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -116,7 +116,7 @@ private[spark] class AppClient( masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index db7c499661319..80c9c13ddec1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -93,7 +93,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def getRunner(operateFun: () => Unit): Runnable = { new Runnable() { - override def run() = Utils.tryOrExit { + override def run(): Unit = Utils.tryOrExit { operateFun() } } @@ -141,7 +141,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } - override def getListing() = applications.values + override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values override def getAppUI(appId: String): Option[SparkUI] = { try { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index af483d560b33e..72f6048239297 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -61,7 +61,7 @@ class HistoryServer( private val appCache = CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(new RemovalListener[String, SparkUI] { - override def onRemoval(rm: RemovalNotification[String, SparkUI]) = { + override def onRemoval(rm: RemovalNotification[String, SparkUI]): Unit = { detachSparkUI(rm.getValue()) } }) @@ -149,14 +149,14 @@ class HistoryServer( * * @return List of all known applications. */ - def getApplicationList() = provider.getListing() + def getApplicationList(): Iterable[ApplicationHistoryInfo] = provider.getListing() /** * Returns the provider configuration to show in the listing page. * * @return A map with the provider's configuration. */ - def getProviderConfig() = provider.getConfig() + def getProviderConfig(): Map[String, String] = provider.getConfig() } @@ -195,9 +195,7 @@ object HistoryServer extends Logging { server.bind() Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { - override def run() = { - server.stop() - } + override def run(): Unit = server.stop() }) // Wait until the end of the world... or if the HistoryServer process is manually stopped diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index d2d30bfd7fcba..32499b3a784a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -48,7 +48,7 @@ private[master] class FileSystemPersistenceEngine( new File(dir + File.separator + name).delete() } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) files.map(deserializeFromFile[T]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1b42121c8db05..80506621f4d24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -204,7 +204,7 @@ private[master] class Master( self ! 
RevokedLeadership } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 1583bf1f60032..351db8fab2041 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -51,20 +51,27 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial */ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") - def createPersistenceEngine() = { + def createPersistenceEngine(): PersistenceEngine = { logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) } - def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } } private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) { - def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) - def createLeaderElectionAgent(master: LeaderElectable) = + def createPersistenceEngine(): PersistenceEngine = { + new ZooKeeperPersistenceEngine(conf, serializer) + } + + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { new ZooKeeperLeaderElectionAgent(master, conf) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e94aae93e4495..9b3d48c6edc84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -104,7 +104,7 @@ private[spark] class WorkerInfo( "http://" + this.publicAddress + ":" + this.webUiPort } - def setState(state: WorkerState.Value) = { + def setState(state: WorkerState.Value): Unit = { this.state = state } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 1ac6677ad2b6d..a285783f72000 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -46,7 +46,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat zk.delete().forPath(WORKING_DIR + "/" + name) } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) file.map(deserializeFromFile[T]).flatten } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index dee2e4a447c6e..46509e39c0f23 
100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -95,7 +95,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. - def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0 + def hasDrivers: Boolean = activeDrivers.length > 0 || completedDrivers.length > 0 val content =
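For illustration only (not part of the patch; the object and method names below are hypothetical): a minimal, self-contained Scala sketch of the pattern this commit applies throughout core. When a public member has no explicit result type, the compiler infers the type of the right-hand side, which is often a concrete implementation class, so that class silently becomes part of the public signature; annotating the declared type pins the contract to the intended interface.

import scala.collection.mutable.ArrayBuffer

object InferredVsExplicit {
  // Inferred result type is ArrayBuffer[Int]: the mutable implementation type
  // leaks into the public signature and callers may start depending on it.
  def inferred = ArrayBuffer(1, 2, 3)

  // Explicit result type pins the contract to an interface; the body can later
  // switch to another collection without changing the public signature.
  def explicit: scala.collection.Seq[Int] = ArrayBuffer(1, 2, 3)

  def main(args: Array[String]): Unit = {
    println(inferred.getClass.getSimpleName) // ArrayBuffer
    println(explicit.length)                 // 3
  }
}

The same reasoning applies to the overrides above (e.g. receiveWithLogging, getListing, toString): spelling out the return type documents the API and keeps binary and source compatibility checks (such as MiMa, mentioned in the commit message) meaningful.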
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 27a9eabb1ede7..e0948e16ef354 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -56,8 +56,14 @@ private[deploy] class DriverRunner( private var finalExitCode: Option[Int] = None // Decoupled for testing - def setClock(_clock: Clock) = clock = _clock - def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper + def setClock(_clock: Clock): Unit = { + clock = _clock + } + + def setSleeper(_sleeper: Sleeper): Unit = { + sleeper = _sleeper + } + private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) @@ -155,7 +161,7 @@ private[deploy] class DriverRunner( private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) { builder.directory(baseDir) - def initialize(process: Process) = { + def initialize(process: Process): Unit = { // Redirect stdout and stderr to files val stdout = new File(baseDir, "stdout") CommandUtils.redirectStream(process.getInputStream, stdout) @@ -169,8 +175,8 @@ private[deploy] class DriverRunner( runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) } - def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit, - supervise: Boolean) { + def runCommandWithRetry( + command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = { // Time to wait between submission retries. var waitSeconds = 1 // A run of this many seconds resets the exponential back-off. @@ -216,8 +222,8 @@ private[deploy] trait ProcessBuilderLike { } private[deploy] object ProcessBuilderLike { - def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike { - def start() = processBuilder.start() - def command = processBuilder.command() + def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { + override def start(): Process = processBuilder.start() + override def command: Seq[String] = processBuilder.command() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c1b0a295f9f74..c4c24a7866aa3 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -275,7 +275,7 @@ private[worker] class Worker( } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 09d866fb0cd90..e0790274d7d3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -50,7 +50,7 @@ private[spark] class WorkerWatcher(workerUrl: String) private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) 
=> logInfo(s"Successfully connected to $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index dd19e4947db1e..b5205d4e997ae 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -62,7 +62,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index 41925f7e97e84..3e47d13f7545d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -33,7 +33,7 @@ private[spark] case object TriggerThreadDump private[spark] class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case TriggerThreadDump => sender ! Utils.getThreadDump() } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 07b152651dedf..06152f16ae618 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,13 +17,10 @@ package org.apache.spark.executor -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark.executor.DataReadMethod.DataReadMethod - import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} /** @@ -44,14 +41,14 @@ class TaskMetrics extends Serializable { * Host's name the task runs on */ private var _hostname: String = _ - def hostname = _hostname + def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ - def executorDeserializeTime = _executorDeserializeTime + def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value @@ -59,14 +56,14 @@ class TaskMetrics extends Serializable { * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ - def executorRunTime = _executorRunTime + def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value /** * The number of bytes this task transmitted back to the driver as the TaskResult */ private var _resultSize: Long = _ - def resultSize = _resultSize + def resultSize: Long = _resultSize private[spark] def setResultSize(value: Long) = _resultSize = value @@ -74,31 +71,31 @@ class TaskMetrics extends Serializable { * Amount of time the JVM spent in garbage collection while executing this task */ private var _jvmGCTime: Long = _ - def 
jvmGCTime = _jvmGCTime + def jvmGCTime: Long = _jvmGCTime private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value /** * Amount of time spent serializing the task result */ private var _resultSerializationTime: Long = _ - def resultSerializationTime = _resultSerializationTime + def resultSerializationTime: Long = _resultSerializationTime private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value /** * The number of in-memory bytes spilled by this task */ private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value + def memoryBytesSpilled: Long = _memoryBytesSpilled + private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value + private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value /** * The number of on-disk bytes spilled by this task */ private var _diskBytesSpilled: Long = _ - def diskBytesSpilled = _diskBytesSpilled - def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value + def diskBytesSpilled: Long = _diskBytesSpilled + def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read @@ -106,7 +103,7 @@ class TaskMetrics extends Serializable { */ private var _inputMetrics: Option[InputMetrics] = None - def inputMetrics = _inputMetrics + def inputMetrics: Option[InputMetrics] = _inputMetrics /** * This should only be used when recreating TaskMetrics, not when updating input metrics in @@ -128,7 +125,7 @@ class TaskMetrics extends Serializable { */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None - def shuffleReadMetrics = _shuffleReadMetrics + def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics /** * This should only be used when recreating TaskMetrics, not when updating read metrics in @@ -177,17 +174,18 @@ class TaskMetrics extends Serializable { * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, * we can store all the different inputMetrics (one per readMethod). */ - private[spark] def getInputMetricsForReadMethod( - readMethod: DataReadMethod): InputMetrics = synchronized { - _inputMetrics match { - case None => - val metrics = new InputMetrics(readMethod) - _inputMetrics = Some(metrics) - metrics - case Some(metrics @ InputMetrics(method)) if method == readMethod => - metrics - case Some(InputMetrics(method)) => - new InputMetrics(readMethod) + private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = { + synchronized { + _inputMetrics match { + case None => + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + case Some(metrics @ InputMetrics(method)) if method == readMethod => + metrics + case Some(InputMetrics(method)) => + new InputMetrics(readMethod) + } } } @@ -256,14 +254,14 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _bytesRead: Long = _ def bytesRead: Long = _bytesRead - def incBytesRead(bytes: Long) = _bytesRead += bytes + def incBytesRead(bytes: Long): Unit = _bytesRead += bytes /** * Total records read. 
*/ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long) = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. @@ -293,15 +291,15 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { * Total bytes written */ private var _bytesWritten: Long = _ - def bytesWritten = _bytesWritten - private[spark] def setBytesWritten(value : Long) = _bytesWritten = value + def bytesWritten: Long = _bytesWritten + private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value /** * Total records written */ private var _recordsWritten: Long = 0L - def recordsWritten = _recordsWritten - private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value + def recordsWritten: Long = _recordsWritten + private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value } /** @@ -314,7 +312,7 @@ class ShuffleReadMetrics extends Serializable { * Number of remote blocks fetched in this shuffle by this task */ private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched = _remoteBlocksFetched + def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value @@ -322,7 +320,7 @@ class ShuffleReadMetrics extends Serializable { * Number of local blocks fetched in this shuffle by this task */ private var _localBlocksFetched: Int = _ - def localBlocksFetched = _localBlocksFetched + def localBlocksFetched: Int = _localBlocksFetched private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value @@ -332,7 +330,7 @@ class ShuffleReadMetrics extends Serializable { * still not finished processing block A, it is not considered to be blocking on block B. */ private var _fetchWaitTime: Long = _ - def fetchWaitTime = _fetchWaitTime + def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value @@ -340,7 +338,7 @@ class ShuffleReadMetrics extends Serializable { * Total number of remote bytes read from the shuffle by this task */ private var _remoteBytesRead: Long = _ - def remoteBytesRead = _remoteBytesRead + def remoteBytesRead: Long = _remoteBytesRead private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value @@ -348,24 +346,24 @@ class ShuffleReadMetrics extends Serializable { * Shuffle data that was read from the local disk (as opposed to from a remote executor). */ private var _localBytesRead: Long = _ - def localBytesRead = _localBytesRead + def localBytesRead: Long = _localBytesRead private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value /** * Total bytes fetched in the shuffle by this task (both remote and local). 
*/ - def totalBytesRead = _remoteBytesRead + _localBytesRead + def totalBytesRead: Long = _remoteBytesRead + _localBytesRead /** * Number of blocks fetched in this shuffle by this task (remote or local) */ - def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched + def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched /** * Total number of records read from the shuffle by this task */ private var _recordsRead: Long = _ - def recordsRead = _recordsRead + def recordsRead: Long = _recordsRead private[spark] def incRecordsRead(value: Long) = _recordsRead += value private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } @@ -380,7 +378,7 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for the shuffle by this task */ @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten = _shuffleBytesWritten + def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value @@ -388,7 +386,7 @@ class ShuffleWriteMetrics extends Serializable { * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime= _shuffleWriteTime + def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value @@ -396,7 +394,7 @@ class ShuffleWriteMetrics extends Serializable { * Total number of records written to the shuffle by this task */ @volatile private var _shuffleRecordsWritten: Long = _ - def shuffleRecordsWritten = _shuffleRecordsWritten + def shuffleRecordsWritten: Long = _shuffleRecordsWritten private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 593a62b3e3b32..6cda7772f77bc 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -73,16 +73,16 @@ private[spark] abstract class StreamBasedRecordReader[T]( private var key = "" private var value: T = null.asInstanceOf[T] - override def initialize(split: InputSplit, context: TaskAttemptContext) = {} - override def close() = {} + override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} + override def close(): Unit = {} - override def getProgress = if (processed) 1.0f else 0.0f + override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey = key + override def getCurrentKey: String = key - override def getCurrentValue = value + override def getCurrentValue: T = value - override def nextKeyValue = { + override def nextKeyValue: Boolean = { if (!processed) { val fileIn = new PortableDataStream(split, context, index) value = parseStream(fileIn) @@ -119,7 +119,8 @@ private[spark] class StreamRecordReader( * The format for the PortableDataStream files */ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { - override def 
createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) + : CombineFileRecordReader[String, PortableDataStream] = { new CombineFileRecordReader[String, PortableDataStream]( split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) } @@ -204,7 +205,7 @@ class PortableDataStream( /** * Close the file (if it is currently open) */ - def close() = { + def close(): Unit = { if (isOpen) { try { fileIn.close() diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 21b782edd2a9e..87c2aa481095d 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -52,7 +52,7 @@ trait SparkHadoopMapRedUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 3340673f91156..cfd20392d12f1 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -45,7 +45,7 @@ trait SparkHadoopMapReduceUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 345db36630fd5..9150ad35712a1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} @@ -84,7 +85,7 @@ private[spark] class MetricsSystem private ( /** * Get any UI handlers used by this metrics system; can only be called after start(). 
*/ - def getServletHandlers = { + def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") metricsServlet.map(_.getHandlers).getOrElse(Array()) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 2f65bc8b46609..0c2e212a33074 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -30,8 +30,12 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils._ -private[spark] class MetricsServlet(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class MetricsServlet( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -45,10 +49,12 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[ServletContextHandler]( - createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) - ) + def getHandlers: Array[ServletContextHandler] = { + Array[ServletContextHandler]( + createServletHandler(servletPath, + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + ) + } def getMetricsSnapshot(request: HttpServletRequest): String = { mapper.writeValueAsString(registry) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 0d83d8c425ca4..9fad4e7deacb6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { - def start: Unit - def stop: Unit + def start(): Unit + def stop(): Unit def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index a1a2c00ed1542..1ba25aa74aa02 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -32,11 +32,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - def apply(i: Int) = blockMessages(i) + def apply(i: Int): BlockMessage = blockMessages(i) - def iterator = blockMessages.iterator + def iterator: Iterator[BlockMessage] = blockMessages.iterator - def length = blockMessages.length + def length: Int = blockMessages.length def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala index 3b245c5c7a4f3..9a9e22b0c2366 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -31,9 +31,9 @@ class BufferMessage(id_ : Int, val 
buffers: ArrayBuffer[ByteBuffer], var ackId: val initialSize = currentSize() var gotChunkForSendingOnce = false - def size = initialSize + def size: Int = initialSize - def currentSize() = { + def currentSize(): Int = { if (buffers == null || buffers.isEmpty) { 0 } else { @@ -100,11 +100,11 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: buffers.foreach(_.flip) } - def hasAckId() = (ackId != 0) + def hasAckId(): Boolean = ackId != 0 - def isCompletelyReceived() = !buffers(0).hasRemaining + def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - override def toString = { + override def toString: String = { if (hasAckId) { "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index c2d9578be7ebb..04eb2bf9ba4ab 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -101,9 +101,11 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, socketRemoteConnectionManagerId } - def key() = channel.keyFor(selector) + def key(): SelectionKey = channel.keyFor(selector) - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + def getRemoteAddress(): InetSocketAddress = { + channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + } // Returns whether we have to register for further reads or not. def read(): Boolean = { @@ -280,7 +282,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, /* channel.socket.setSendBufferSize(256 * 1024) */ - override def getRemoteAddress() = address + override def getRemoteAddress(): InetSocketAddress = address val DEFAULT_INTEREST = SelectionKey.OP_READ diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala index 764dc5e5503ed..b3b281ff465f1 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala @@ -18,7 +18,9 @@ package org.apache.spark.network.nio private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + override def toString: String = { + connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + } } private[nio] object ConnectionId { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index ee22c6656e69e..741fe3e1ea750 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -188,7 +188,7 @@ private[nio] class ConnectionManager( private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private val selectorThread = new Thread("connection-manager-thread") { - override def run() = ConnectionManager.this.run() + override def run(): Unit = ConnectionManager.this.run() } selectorThread.setDaemon(true) // start this thread last, since it invokes run(), which accesses members above diff --git 
a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala index cbb37ec5ced1f..1cd13d887c6f6 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -26,7 +26,7 @@ private[nio] case class ConnectionManagerId(host: String, port: Int) { Utils.checkHost(host) assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) + def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index fb4a979b824c3..85d2fe2bf9c20 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -42,7 +42,9 @@ private[nio] abstract class Message(val typ: Long, val id: Int) { def timeTaken(): String = (finishTime - startTime).toString + " ms" - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + override def toString: String = { + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + } } @@ -51,7 +53,7 @@ private[nio] object Message { var lastId = 1 - def getNewId() = synchronized { + def getNewId(): Int = synchronized { lastId += 1 if (lastId == 0) { lastId += 1 diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala index 278c5ac356ef2..a4568e849fa13 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer private[nio] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining + val size: Int = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { + lazy val buffers: ArrayBuffer[ByteBuffer] = { val ab = new ArrayBuffer[ByteBuffer]() ab += header.buffer if (buffer != null) { @@ -35,7 +35,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { ab } - override def toString = { + override def toString: String = { "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala index 6e20f291c5cec..7b3da4bb9d5ee 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -50,8 +50,10 @@ private[nio] class MessageChunkHeader( flip.asInstanceOf[ByteBuffer] } - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + override def toString: String = { + "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index cadd0c7ed19ba..53c4b32c95ab3 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ 
b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -99,7 +99,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { case None => "(partial: " + initialValue + ")" } } - def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) + def getFinalValueInternal(): Option[T] = PartialResult.this.getFinalValueInternal().map(f) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 1cbd684224b7c..9059eb13bb5d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -70,7 +70,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct } - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index b073eba8a1574..5117ccfabfcc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -186,7 +186,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: override val isEmpty = !it.hasNext // initializes/resets to start iterating from the beginning - def resetIterator() = { + def resetIterator(): Iterator[(String, Partition)] = { val iterators = (0 to 2).map( x => prev.partitions.iterator.flatMap(p => { if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None @@ -196,10 +196,10 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: } // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - def hasNext(): Boolean = { !isEmpty } + override def hasNext: Boolean = { !isEmpty } // return the next preferredLocation of some partition of the RDD - def next(): (String, Partition) = { + override def next(): (String, Partition) = { if (it.hasNext) { it.next() } else { @@ -237,7 +237,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: val rotIt = new LocationIterator(prev) // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext()) { + if (!rotIt.hasNext) { (1 to targetLen).foreach(x => groupArr += PartitionGroup()) return } @@ -343,7 +343,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: private case class PartitionGroup(prefLoc: Option[String] = None) { var arr = mutable.ArrayBuffer[Partition]() - def size = arr.size + def size: Int = arr.size } private object PartitionGroup { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 486e86ce1bb19..f77abac42b623 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,8 +215,7 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = 
context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes @@ -240,7 +239,7 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - override def getNext() = { + override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { @@ -337,11 +336,11 @@ private[spark] object HadoopRDD extends Logging { * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. */ - def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key) + def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key) + def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - def putCachedMetadata(key: String, value: Any) = + private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) /** Add Hadoop configuration specific to a single partition and attempt. */ @@ -371,7 +370,7 @@ private[spark] object HadoopRDD extends Logging { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[HadoopPartition] val inputSplit = partition.inputSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index e2267861e79df..0c28f045e46e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.sql.{Connection, ResultSet} +import java.sql.{PreparedStatement, Connection, ResultSet} import scala.reflect.ClassTag @@ -28,8 +28,9 @@ import org.apache.spark.util.NextIterator import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx + override def index: Int = idx } + // TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private /** * An RDD that executes an SQL query on a JDBC connection and reads results. 
@@ -70,7 +71,8 @@ class JdbcRDD[T: ClassTag]( }).toArray } - override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] + { context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() @@ -88,7 +90,7 @@ class JdbcRDD[T: ClassTag]( stmt.setLong(2, part.upper) val rs = stmt.executeQuery() - override def getNext: T = { + override def getNext(): T = { if (rs.next()) { mapRow(rs) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4883fb828814c..a838aac6e8d1a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -31,6 +31,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7fb94840df99c..2ab967f4bb313 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -238,7 +238,7 @@ private[spark] object NewHadoopRDD { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[NewHadoopPartition] val inputSplit = partition.serializableHadoopSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index f12d0cffaba34..e2394e28f8d26 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -98,7 +98,7 @@ private[spark] class ParallelCollectionRDD[T: ClassTag]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = { + override def compute(s: Partition, context: TaskContext): Iterator[T] = { new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index f781a8d776f2a..a00f4c1cdff91 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -40,7 +40,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) } } 
@@ -59,8 +59,10 @@ class PartitionPruningRDD[T: ClassTag]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + firstParent[T].iterator( + split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + } override protected def getPartitions: Array[Partition] = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions @@ -74,7 +76,7 @@ object PartitionPruningRDD { * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD * when its type T is not known at compile time. */ - def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { + def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean): PartitionPruningRDD[T] = { new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index ed79032893d33..dc60d48927624 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -149,10 +149,10 @@ private[spark] class PipedRDD[T: ClassTag]( }.start() // Return an iterator that read lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines + val lines = Source.fromInputStream(proc.getInputStream).getLines() new Iterator[String] { - def next() = lines.next() - def hasNext = { + def next(): String = lines.next() + def hasNext: Boolean = { if (lines.hasNext) { true } else { @@ -162,7 +162,7 @@ private[spark] class PipedRDD[T: ClassTag]( } // cleanup task working directory if used - if (workInTaskDirectory == true) { + if (workInTaskDirectory) { scala.util.control.Exception.ignoring(classOf[IOException]) { Utils.deleteRecursively(new File(taskDirectory)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a4c74ed03e330..ddbfd5624e741 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -186,7 +186,7 @@ abstract class RDD[T: ClassTag]( } /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. 
*/ - def getStorageLevel = storageLevel + def getStorageLevel: StorageLevel = storageLevel // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed @@ -746,13 +746,13 @@ abstract class RDD[T: ClassTag]( def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { - def hasNext = (thisIter.hasNext, otherIter.hasNext) match { + def hasNext: Boolean = (thisIter.hasNext, otherIter.hasNext) match { case (true, true) => true case (false, false) => false case _ => throw new SparkException("Can only zip RDDs with " + "same number of elements in each partition") } - def next = (thisIter.next, otherIter.next) + def next(): (T, U) = (thisIter.next(), otherIter.next()) } } } @@ -868,8 +868,8 @@ abstract class RDD[T: ClassTag]( // Our partitioner knows how to handle T (which, since we have a partitioner, is // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) + override def numPartitions: Int = p.numPartitions + override def getPartition(k: Any): Int = p.getPartition(k.asInstanceOf[(Any, _)]._1) } // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies // anyway, and when calling .keys, will not have a partitioner set, even though @@ -1394,7 +1394,7 @@ abstract class RDD[T: ClassTag]( } /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ - def context = sc + def context: SparkContext = sc /** * Private API for changing an RDD's ClassTag. diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index d9fe6847254fa..2dc47f95937cb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,14 +17,12 @@ package org.apache.spark.rdd -import scala.reflect.ClassTag - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { - override val index = idx + override val index: Int = idx override def hashCode(): Int = idx } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index ed24ea22a661c..c27f435eb9c5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -105,7 +105,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { + def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index aece683ff3199..4239e7e22af89 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -44,7 +44,7 @@ private[spark] class UnionPartition[T: ClassTag]( var parentPartition: Partition = 
rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(parentPartition) + def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition) override val index: Int = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 95b2dd954e9f4..d0be304762e1f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -32,7 +32,7 @@ private[spark] class ZippedPartitionsPartition( override val index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues + def partitions: Seq[Partition] = partitionValues @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index fa83372bb4d11..e0edd7d4ae968 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -39,8 +39,11 @@ class AccumulableInfo ( } object AccumulableInfo { - def apply(id: Long, name: String, update: Option[String], value: String) = + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { new AccumulableInfo(id, name, update, value) + } - def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) + def apply(id: Long, name: String, value: String): AccumulableInfo = { + new AccumulableInfo(id, name, None, value) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8feac6cb6b7a1..b405bd3338e7c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -946,7 +946,7 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { + def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 34fa6d27c3a45..c0d889360ae99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -149,47 +149,60 @@ private[spark] class EventLoggingListener( } // Events that do not trigger a flush - override def onStageSubmitted(event: SparkListenerStageSubmitted) = - logEvent(event) - override def onTaskStart(event: SparkListenerTaskStart) = - logEvent(event) - override def onTaskGettingResult(event: SparkListenerTaskGettingResult) = - logEvent(event) - override def onTaskEnd(event: SparkListenerTaskEnd) = - logEvent(event) - override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate) = - logEvent(event) + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) + + override def onTaskStart(event: SparkListenerTaskStart): Unit = 
logEvent(event) + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = logEvent(event) + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = logEvent(event) + + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = logEvent(event) // Events that trigger a flush - override def onStageCompleted(event: SparkListenerStageCompleted) = - logEvent(event, flushLogger = true) - override def onJobStart(event: SparkListenerJobStart) = - logEvent(event, flushLogger = true) - override def onJobEnd(event: SparkListenerJobEnd) = + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded) = + } + + override def onJobStart(event: SparkListenerJobStart): Unit = logEvent(event, flushLogger = true) + + override def onJobEnd(event: SparkListenerJobEnd): Unit = logEvent(event, flushLogger = true) + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved) = + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { logEvent(event, flushLogger = true) - override def onUnpersistRDD(event: SparkListenerUnpersistRDD) = + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { logEvent(event, flushLogger = true) - override def onApplicationStart(event: SparkListenerApplicationStart) = + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { logEvent(event, flushLogger = true) - override def onApplicationEnd(event: SparkListenerApplicationEnd) = + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { logEvent(event, flushLogger = true) - override def onExecutorAdded(event: SparkListenerExecutorAdded) = + } + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { logEvent(event, flushLogger = true) - override def onExecutorRemoved(event: SparkListenerExecutorRemoved) = + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { logEvent(event, flushLogger = true) + } // No-op because logging every update would be overkill - override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. 
*/ - def stop() = { + def stop(): Unit = { writer.foreach(_.close()) val target = new Path(logPath) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 8aa528ac573d0..e55b76c36cc5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -57,7 +57,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener private val stageIdToJobId = new HashMap[Int, Int] private val jobIdToStageIds = new HashMap[Int, Seq[Int]] private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } createLogDir() diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 29879b374b801..382b09422a4a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -34,7 +34,7 @@ private[spark] class JobWaiter[T]( @volatile private var _jobFinished = totalTasks == 0 - def jobFinished = _jobFinished + def jobFinished: Boolean = _jobFinished // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 759df023a6dcf..a3caa9f000c89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -160,7 +160,7 @@ private[spark] object OutputCommitCoordinator { class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) extends Actor with ActorLogReceive with Logging { - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) case StopCoordinator => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 4a9ff918afe25..e074ce6ebff0b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -64,5 +64,5 @@ private[spark] class ResultTask[T, U]( // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" + override def toString: String = "ResultTask(" + stageId + ", " + partitionId + ")" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 79709089c0da4..fd0d484b45460 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -47,7 +47,7 @@ private[spark] class ShuffleMapTask( /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ def this(partitionId: Int) { - this(0, null, new Partition { override def index = 0 }, null) + this(0, null, new Partition { override def index: Int = 0 }, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -83,5 +83,5 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) + override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 52720d48ca67f..b711ff209af94 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -300,7 +300,7 @@ private[spark] object StatsReportListener extends Logging { } def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { - def f(d: Double) = format.format(d) + def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -346,7 +346,7 @@ private[spark] object StatsReportListener extends Logging { /** * Reformat a time interval in milliseconds to a prettier format for output */ - def millisToString(ms: Long) = { + def millisToString(ms: Long): String = { val (size, units) = if (ms > hours) { (ms.toDouble / hours, "hours") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index cc13f57a49b89..4cbc6e84a6bdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -133,7 +133,7 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - override def toString = "Stage " + id + override def toString: String = "Stage " + id override def hashCode(): Int = id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 10c685f29d3ac..da07ce2c6ea49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -29,23 +29,22 @@ private[spark] sealed trait TaskLocation { /** * A location that includes both a host and an executor id on that host. */ -private [spark] case class ExecutorCacheTaskLocation(override val host: String, - val executorId: String) extends TaskLocation { -} +private [spark] +case class ExecutorCacheTaskLocation(override val host: String, executorId: String) + extends TaskLocation /** * A location on a host. */ private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation { - override def toString = host + override def toString: String = host } /** * A location on a host that is cached by HDFS. */ -private [spark] case class HDFSCacheTaskLocation(override val host: String) - extends TaskLocation { - override def toString = TaskLocation.inMemoryLocationTag + host +private [spark] case class HDFSCacheTaskLocation(override val host: String) extends TaskLocation { + override def toString: String = TaskLocation.inMemoryLocationTag + host } private[spark] object TaskLocation { @@ -54,14 +53,16 @@ private[spark] object TaskLocation { // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. 
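// A simplified standalone mirror of the TaskLocation change above: the companion `apply`
// methods are annotated to return the sealed trait, so the concrete case classes do not leak
// into public signatures. `Location` and friends are illustrative names, not Spark API.
sealed trait Location { def host: String }

case class ExecutorLocation(override val host: String, executorId: String) extends Location

case class HostLocation(override val host: String) extends Location {
  override def toString: String = host
}

case class CachedHostLocation(override val host: String) extends Location {
  override def toString: String = Location.cachedTag + host
}

object Location {
  val cachedTag: String = "hdfs_cache_"

  def apply(host: String, executorId: String): Location = ExecutorLocation(host, executorId)

  def apply(str: String): Location = {
    val hstr = str.stripPrefix(cachedTag)
    if (hstr == str) HostLocation(str) else CachedHostLocation(hstr)
  }
}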
val inMemoryLocationTag = "hdfs_cache_" - def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId) + def apply(host: String, executorId: String): TaskLocation = { + new ExecutorCacheTaskLocation(host, executorId) + } /** * Create a TaskLocation from a string returned by getPreferredLocations. * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the * location is cached. */ - def apply(str: String) = { + def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { new HostTaskLocation(str) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f33fd4450b2a6..076b36e86c0ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -373,17 +373,17 @@ private[spark] class TaskSchedulerImpl( } def handleSuccessfulTask( - taskSetManager: TaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]): Unit = synchronized { taskSetManager.handleSuccessfulTask(tid, taskResult) } def handleFailedTask( - taskSetManager: TaskSetManager, - tid: Long, - taskState: TaskState, - reason: TaskEndReason) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: TaskEndReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to @@ -423,7 +423,7 @@ private[spark] class TaskSchedulerImpl( starvationTimer.cancel() } - override def defaultParallelism() = backend.defaultParallelism() + override def defaultParallelism(): Int = backend.defaultParallelism() // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 529237f0d35dc..d509881c74fef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.Arrays +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -29,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -97,7 +99,8 @@ private[spark] class TaskSetManager( var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] - override def runningTasks = runningTasksSet.size + + override def runningTasks: Int = runningTasksSet.size // True once no more tasks should be launched for this task set manager. 
TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the @@ -168,9 +171,9 @@ private[spark] class TaskSetManager( var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level - override def schedulableQueue = null + override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null - override def schedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.NONE var emittedTaskSizeWarning = false @@ -585,7 +588,7 @@ private[spark] class TaskSetManager( /** * Marks the task as getting result and notifies the DAG Scheduler */ - def handleTaskGettingResult(tid: Long) = { + def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) @@ -612,7 +615,7 @@ private[spark] class TaskSetManager( /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { + def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index info.markSuccessful() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 87ebf31139ce9..5d258d9da4d1a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -85,7 +85,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } - def receiveWithLogging = { + def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisterExecutor(executorId, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f14aaeea0a25c..5a38ad9f2b12c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] abstract class YarnSchedulerBackend( context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case RegisterClusterManager => logInfo(s"ApplicationMaster registered as $sender") amActor = Some(sender) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala index aa3ec0f8cfb9c..8df4f3b554c41 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala @@ -24,7 +24,7 @@ private[spark] object MemoryUtils { val OVERHEAD_FRACTION = 0.10 val OVERHEAD_MINIMUM = 384 - def calculateTotalMemory(sc: SparkContext) = { + def calculateTotalMemory(sc: SparkContext): Int = { 
sc.conf.getInt("spark.mesos.executor.memoryOverhead", math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 06bb527522141..b381436839227 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -387,7 +387,7 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + override def defaultParallelism(): Int = sc.conf.getInt("spark.default.parallelism", 8) override def applicationId(): String = Option(appId).getOrElse { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index d95426d918e19..eb3f999b5b375 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -59,7 +59,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -117,7 +117,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! ReviveOffers } - override def defaultParallelism() = + override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 1baa0e009f3ae..dfbde7c8a1b0d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -59,9 +59,10 @@ private[spark] class JavaSerializationStream( } private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { + extends DeserializationStream { + private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index dc7aa99738c17..f83bcaa5cc09e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -60,7 +60,7 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) - def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala 
b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 7de2f9cbb2866..660df00bc32f5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -106,7 +106,7 @@ class FileShuffleBlockManager(conf: SparkConf) * when the writers are closed successfully */ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, - writeMetrics: ShuffleWriteMetrics) = { + writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) @@ -268,7 +268,7 @@ object FileShuffleBlockManager { new PrimitiveVector[Long]() } - def apply(bucketId: Int) = files(bucketId) + def apply(bucketId: Int): File = files(bucketId) def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { assert(offsets.length == lengths.length) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index b292587d37028..87fd161e06c85 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -80,7 +80,7 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { * end of the output file. This will be used by getBlockLocation to figure out where each block * begins and ends. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]) = { + def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) try { @@ -121,5 +121,5 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { } } - override def stop() = {} + override def stop(): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 1f012941c85ab..c186fd360fef6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -35,13 +35,13 @@ sealed abstract class BlockId { def name: String // convenience methods - def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None - def isRDD = isInstanceOf[RDDBlockId] - def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] + def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD: Boolean = isInstanceOf[RDDBlockId] + def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] + def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] - override def toString = name - override def hashCode = name.hashCode + override def toString: String = name + override def hashCode: Int = name.hashCode override def equals(other: Any): Boolean = other match { case o: BlockId => getClass == o.getClass && name.equals(o.name) case _ => false @@ -50,54 +50,54 @@ sealed abstract class BlockId { @DeveloperApi case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { - def name = "rdd_" + rddId + "_" + splitIndex + override def name: String = "rdd_" + rddId + "_" + splitIndex } // Format of the shuffle block ids (including data and index) should be kept in sync 
with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). @DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" } @DeveloperApi case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { - def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) + override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } @DeveloperApi case class TaskResultBlockId(taskId: Long) extends BlockId { - def name = "taskresult_" + taskId + override def name: String = "taskresult_" + taskId } @DeveloperApi case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { - def name = "input-" + streamId + "-" + uniqueId + override def name: String = "input-" + streamId + "-" + uniqueId } /** Id associated with temporary local data managed as blocks. Not serializable. */ private[spark] case class TempLocalBlockId(id: UUID) extends BlockId { - def name = "temp_local_" + id + override def name: String = "temp_local_" + id } /** Id associated with temporary shuffle data managed as blocks. Not serializable. */ private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId { - def name = "temp_shuffle_" + id + override def name: String = "temp_shuffle_" + id } // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { - def name = "test_" + id + override def name: String = "test_" + id } @DeveloperApi @@ -112,7 +112,7 @@ object BlockId { val TEST = "test_(.*)".r /** Converts a BlockId "name" String back into a BlockId. 
*/ - def apply(id: String) = id match { + def apply(id: String): BlockId = id match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b177a59c721df..a6f1ebf325a7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -77,11 +77,11 @@ class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = s"BlockManagerId($executorId, $host, $port)" + override def toString: String = s"BlockManagerId($executorId, $host, $port)" override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case id: BlockManagerId => executorId == id.executorId && port == id.port && host == id.host case _ => @@ -100,10 +100,10 @@ private[spark] object BlockManagerId { * @param port Port of the block manager. * @return A new [[org.apache.spark.storage.BlockManagerId]]. */ - def apply(execId: String, host: String, port: Int) = + def apply(execId: String, host: String, port: Int): BlockManagerId = getCachedBlockManagerId(new BlockManagerId(execId, host, port)) - def apply(in: ObjectInput) = { + def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() obj.readExternal(in) getCachedBlockManagerId(obj) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 654796f23c96e..061964826f08b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -79,7 +79,7 @@ class BlockManagerMaster( * Check if block manager master has a block. Note that this can be used to check for only * those blocks that are reported to block manager master. */ - def contains(blockId: BlockId) = { + def contains(blockId: BlockId): Boolean = { !getLocations(blockId).isEmpty } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 787b0f96bec32..5b5328016124e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,7 +52,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -421,7 +421,7 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. 
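// A standalone sketch of the `receive`/`receiveWithLogging` annotations above, assuming the
// Akka 2.3-era actor API this code is written against: an actor's message handler is a
// PartialFunction[Any, Unit], and that is the type now spelled out explicitly.
// `RegistryActor` and its messages are illustrative names, not Spark API.
import akka.actor.Actor

case class RegisterWorker(id: String)
case object WorkerCount

class RegistryActor extends Actor {
  private var workers = Set.empty[String]

  override def receive: PartialFunction[Any, Unit] = {
    case RegisterWorker(id) =>
      workers += id
      sender ! true
    case WorkerCount =>
      sender ! workers.size
  }
}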
private val _blocks = new JHashMap[BlockId, BlockStatus] - def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8462871e798a5..52fb896c4e21f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -38,7 +38,7 @@ class BlockManagerSlaveActor( import context.dispatcher // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 81164178b9e8e..f703e50b6b0ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -82,11 +82,13 @@ private[spark] class DiskBlockObjectWriter( { /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]) = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) - override def close() = out.close() - override def flush() = out.flush() + override def write(i: Int): Unit = callWithTiming(out.write(i)) + override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + callWithTiming(out.write(b, off, len)) + } + override def close(): Unit = out.close() + override def flush(): Unit = out.flush() } /** The file channel, used for repositioning / truncating the file. */ @@ -141,8 +143,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - def sync = fos.getFD.sync() - callWithTiming(sync) + callWithTiming { + fos.getFD.sync() + } } objOut.close() diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 132502b75f8cd..95e2d688d9b17 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,5 +24,7 @@ import java.io.File * based off an offset and a length. 
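// A standalone sketch of the DiskBlockObjectWriter pattern above: a decorator over
// java.io.OutputStream whose overrides all carry an explicit `: Unit`, plus a by-name timing
// helper in the block form the patch switches to (`callWithTiming { fos.getFD.sync() }`).
// `TimedOutputStream` and `timed` are illustrative names, not Spark API.
import java.io.OutputStream

class TimedOutputStream(out: OutputStream) extends OutputStream {
  private var _writeTimeNanos = 0L

  // Times an arbitrary by-name block and accumulates the elapsed nanoseconds.
  private def timed[T](body: => T): T = {
    val start = System.nanoTime()
    try body finally { _writeTimeNanos += System.nanoTime() - start }
  }

  override def write(i: Int): Unit = timed(out.write(i))
  override def write(b: Array[Byte]): Unit = timed(out.write(b))
  override def write(b: Array[Byte], off: Int, len: Int): Unit = timed(out.write(b, off, len))
  override def flush(): Unit = timed(out.flush())
  override def close(): Unit = out.close()

  def writeTimeNanos: Long = _writeTimeNanos
}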
*/ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 120c327a7e580..0186eb30a1905 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -36,7 +36,7 @@ class RDDInfo( def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0 - override def toString = { + override def toString: String = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + "MemorySize: %s; TachyonSize: %s; DiskSize: %s").format( @@ -44,7 +44,7 @@ class RDDInfo( bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) } - override def compare(that: RDDInfo) = { + override def compare(that: RDDInfo): Int = { this.id - that.id } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index e5e1cf5a69a19..134abea866218 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -50,11 +50,11 @@ class StorageLevel private( def this() = this(false, true, false, false) // For deserialization - def useDisk = _useDisk - def useMemory = _useMemory - def useOffHeap = _useOffHeap - def deserialized = _deserialized - def replication = _replication + def useDisk: Boolean = _useDisk + def useMemory: Boolean = _useMemory + def useOffHeap: Boolean = _useOffHeap + def deserialized: Boolean = _deserialized + def replication: Int = _replication assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") @@ -80,7 +80,7 @@ class StorageLevel private( false } - def isValid = (useMemory || useDisk || useOffHeap) && (replication > 0) + def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0) def toInt: Int = { var ret = 0 @@ -183,7 +183,7 @@ object StorageLevel { useMemory: Boolean, useOffHeap: Boolean, deserialized: Boolean, - replication: Int) = { + replication: Int): StorageLevel = { getCachedStorageLevel( new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) } @@ -197,7 +197,7 @@ object StorageLevel { useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, - replication: Int = 1) = { + replication: Int = 1): StorageLevel = { getCachedStorageLevel(new StorageLevel(useDisk, useMemory, false, deserialized, replication)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index def49e80a3605..7d75929b96f75 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,7 +19,6 @@ package org.apache.spark.storage import scala.collection.mutable -import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -32,7 +31,7 @@ class StorageStatusListener extends SparkListener { // This maintains only blocks that are cached (i.e. 
storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() - def storageStatusList = executorIdToStorageStatus.values.toSeq + def storageStatusList: Seq[StorageStatus] = executorIdToStorageStatus.values.toSeq /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { @@ -56,7 +55,7 @@ class StorageStatusListener extends SparkListener { } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { @@ -67,7 +66,7 @@ class StorageStatusListener extends SparkListener { } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { updateStorageStatus(unpersistRDD.rddId) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala index b86abbda1d3e7..65fa81704c365 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala @@ -24,5 +24,7 @@ import tachyon.client.TachyonFile * a length. */ private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 0c24ad2760e08..adfa6bbada256 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,7 +60,7 @@ private[spark] class SparkUI private ( } initialize() - def getAppName = appName + def getAppName: String = appName /** Set the app name for this UI. 
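// A standalone sketch of the `: Unit = synchronized { ... }` shape used by the listener and
// scheduler methods above: `synchronized` is an expression that returns its body's value, so
// the explicit result type states the contract instead of leaving it to inference.
// `ExecutorStats` is an illustrative name, not Spark API.
class ExecutorStats {
  private val tasksPerExecutor = scala.collection.mutable.HashMap[String, Int]()

  def recordTask(executorId: String): Unit = synchronized {
    tasksPerExecutor(executorId) = tasksPerExecutor.getOrElse(executorId, 0) + 1
  }

  def snapshot: Map[String, Int] = synchronized {
    tasksPerExecutor.toMap
  }
}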
*/ def setAppName(name: String) { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index b5022fe853c49..f07864141a21c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -149,9 +149,11 @@ private[spark] object UIUtils extends Logging { } } - def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource + def prependBaseUri(basePath: String = "", resource: String = ""): String = { + uiRoot + basePath + resource + } - def commonHeaderNodes = { + def commonHeaderNodes: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index fc1844600f1cb..19ac7a826e306 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -51,7 +51,7 @@ private[spark] object UIWorkloadGenerator { val nJobSet = args(2).toInt val sc = new SparkContext(conf) - def setProperties(s: String) = { + def setProperties(s: String): Unit = { if(schedulingMode == SchedulingMode.FAIR) { sc.setLocalProperty("spark.scheduler.pool", s) } @@ -59,7 +59,7 @@ private[spark] object UIWorkloadGenerator { } val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = new Random().nextFloat() + def nextFloat(): Float = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( ("Count", baseData.count), diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 3afd7ef07d7c9..69053fe44d7e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.ExceptionFailure import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ -import org.apache.spark.storage.StorageStatusListener +import org.apache.spark.storage.{StorageStatus, StorageStatusListener} import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { @@ -55,19 +55,19 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToShuffleWrite = HashMap[String, Long]() val executorToLogUrls = HashMap[String, Map[String, String]]() - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList - override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) = synchronized { + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap } - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId diff --git 
a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 937d95a934b59..949e80d30f5eb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -73,7 +73,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Misc: val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() - def blockManagerIds = executorIdToBlockManagerId.values.toSeq + def blockManagerIds: Seq[BlockManagerId] = executorIdToBlockManagerId.values.toSeq var schedulingMode: Option[SchedulingMode] = None @@ -146,7 +146,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { val jobGroup = for ( props <- Option(jobStart.properties); group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -182,7 +182,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) @@ -219,7 +219,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { val stage = stageCompleted.stageInfo stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { @@ -260,7 +260,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val stage = stageSubmitted.stageInfo activeStages(stage.stageId) = stage pendingStages.remove(stage.stageId) @@ -288,7 +288,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { @@ -312,7 +312,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // stageToTaskInfos already has the updated status. } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task // completion event is for. Let's just drop it here. 
This means we might have some speculation diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index b2bbfdee56946..7ffcf291b5cc6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -24,7 +24,7 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { val sc = parent.sc val killEnabled = parent.killEnabled - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) val listener = parent.jobProgressListener attachPage(new AllJobsPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 110f8780a9a12..e03442894c5cc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} +import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -170,7 +170,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") - def accumulableRow(acc: AccumulableInfo) = <tr><td>{acc.name}</td><td>{acc.value}</td></tr> + def accumulableRow(acc: AccumulableInfo): Elem = + <tr><td>{acc.name}</td><td>{acc.value}</td></tr> val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.values.toSeq) @@ -293,10 +294,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantiles(data: Seq[Double]) = + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) = { + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) + : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator getDistributionQuantiles(data).map(d => <td>{s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"}</td> diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 937261de00e3a..1bd2d87e00796 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -32,10 +32,10 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" attachPage(new StagePage(this)) attachPage(new PoolPage(this)) - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) - def handleKillRequest(request: HttpServletRequest) = { - if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { + def handleKillRequest(request: HttpServletRequest): Unit = { + if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index dbf1ceeda1878..711a3697bda15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -94,11 +94,11 @@ private[jobs] object UIData { var taskData = new HashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] - def hasInput = inputBytes > 0 - def hasOutput = outputBytes > 0 - def hasShuffleRead = shuffleReadTotalBytes > 0 - def hasShuffleWrite = shuffleWriteBytes > 0 - def hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 + def hasInput: Boolean = inputBytes > 0 + def hasOutput: Boolean = outputBytes > 0 + def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 + def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 + def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0 } /** diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index a81291d505583..045bd784990d1 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@
-40,10 +40,10 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList /** Filter RDD info to include only those with cached partitions */ - def rddInfoList = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq + def rddInfoList: Seq[RDDInfo] = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq /** Update the storage info of the RDDs whose blocks are among the given updated blocks */ private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { @@ -56,19 +56,19 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar * Assumes the storage status list is fully up-to-date. This implies the corresponding * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener. */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val metrics = taskEnd.taskMetrics if (metrics != null && metrics.updatedBlocks.isDefined) { updateRDDInfo(metrics.updatedBlocks.get) } } - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val rddInfos = stageSubmitted.stageInfo.rddInfos rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { // Remove all partitions that are no longer cached in current completed stage val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet _rddInfoMap.retain { case (id, info) => @@ -76,7 +76,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { _rddInfoMap.remove(unpersistRDD.rddId) } } diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 390310243ee0a..9044aaeef2d48 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -27,8 +27,8 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat // scalastyle:on private[this] var completed = false - def next() = sub.next() - def hasNext = { + def next(): A = sub.next() + def hasNext: Boolean = { val r = sub.hasNext if (!r && !completed) { completed = true @@ -37,13 +37,13 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat r } - def completion() + def completion(): Unit } private[spark] object CompletionIterator { - def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { + def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A, I] = { new CompletionIterator[A,I](sub) { - def completion() = 
completionFunction + def completion(): Unit = completionFunction } } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index a465298c8c5ab..9aea8efa38c7a 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -57,7 +57,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va out.println } - def statCounter = StatCounter(data.slice(startIdx, endIdx)) + def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) /** * print a summary of this distribution to the given PrintStream. diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index cf89c1782fd67..1718554061985 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -39,31 +39,27 @@ private[spark] class ManualClock(private var time: Long) extends Clock { /** * @param timeToSet new time (in milliseconds) that the clock should represent */ - def setTime(timeToSet: Long) = - synchronized { - time = timeToSet - notifyAll() - } + def setTime(timeToSet: Long): Unit = synchronized { + time = timeToSet + notifyAll() + } /** * @param timeToAdd time (in milliseconds) to add to the clock's time */ - def advance(timeToAdd: Long) = - synchronized { - time += timeToAdd - notifyAll() - } + def advance(timeToAdd: Long): Unit = synchronized { + time += timeToAdd + notifyAll() + } /** * @param targetTime block until the clock time is set or advanced to at least this time * @return current time reported by the clock when waiting finishes */ - def waitTillTime(targetTime: Long): Long = - synchronized { - while (time < targetTime) { - wait(100) - } - getTimeMillis() + def waitTillTime(targetTime: Long): Long = synchronized { + while (time < targetTime) { + wait(100) } - + getTimeMillis() + } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index ac40f19ed6799..375ed430bde45 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -67,14 +67,15 @@ private[spark] object MetadataCleanerType extends Enumeration { type MetadataCleanerType = Value - def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = - "spark.cleaner.ttl." + which.toString + def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = { + "spark.cleaner.ttl." + which.toString + } } // TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. 
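// A standalone mirror of the ManualClock rewrite above: each mutator now declares an explicit
// `: Unit` and puts `synchronized` on the definition line, while `waitTillTime` keeps its
// `Long` result. `SimpleManualClock` is an illustrative name, not Spark API.
class SimpleManualClock(private var time: Long = 0L) {

  def getTimeMillis(): Long = synchronized { time }

  def setTime(timeToSet: Long): Unit = synchronized {
    time = timeToSet
    notifyAll()
  }

  def advance(timeToAdd: Long): Unit = synchronized {
    time += timeToAdd
    notifyAll()
  }

  def waitTillTime(targetTime: Long): Long = synchronized {
    while (time < targetTime) {
      wait(100)
    }
    getTimeMillis()
  }
}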
private[spark] object MetadataCleaner { - def getDelaySeconds(conf: SparkConf) = { + def getDelaySeconds(conf: SparkConf): Int = { conf.getInt("spark.cleaner.ttl", -1) } diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index 74fa77b68de0b..dad888548ed10 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -43,7 +43,7 @@ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef this } - override def toString = "(" + _1 + "," + _2 + ")" + override def toString: String = "(" + _1 + "," + _2 + ")" override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] } diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 6d8d9e8da3678..73d126ff6254e 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -22,7 +22,7 @@ package org.apache.spark.util */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { - override def findClass(name: String) = { + override def findClass(name: String): Class[_] = { super.findClass(name) } diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala index 770ff9d5ad6ae..a06b6f84ef11b 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -27,7 +27,7 @@ import java.nio.channels.Channels */ private[spark] class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { - def value = buffer + def value: ByteBuffer = buffer private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { val length = in.readInt() diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index d80eed455c427..8586da1996cf3 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -141,8 +141,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { object StatCounter { /** Build a StatCounter from a list of values. */ - def apply(values: TraversableOnce[Double]) = new StatCounter(values) + def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values) /** Build a StatCounter from a list of values passed as variable-length arguments. 
*/ - def apply(values: Double*) = new StatCounter(values) + def apply(values: Double*): StatCounter = new StatCounter(values) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f5be5856c2109..310c0c109416c 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -82,7 +82,7 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo this } - override def update(key: A, value: B) = this += ((key, value)) + override def update(key: A, value: B): Unit = this += ((key, value)) override def apply(key: A): B = internalMap.apply(key) @@ -92,14 +92,14 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) + override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f) def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) def toMap: Map[A, B] = iterator.toMap /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ - def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime) /** Remove entries with values that are no longer strongly reachable. */ def clearNullValues() { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fa56bb09e4e5c..d9a671687aad0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -85,7 +85,7 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } ois.readObject.asInstanceOf[T] @@ -106,11 +106,10 @@ private[spark] object Utils extends Logging { /** Serialize via nested stream using specific serializer */ def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)( - f: SerializationStream => Unit) = { + f: SerializationStream => Unit): Unit = { val osWrapper = ser.serializeStream(new OutputStream { - def write(b: Int) = os.write(b) - - override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + override def write(b: Int): Unit = os.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len) }) try { f(osWrapper) @@ -121,10 +120,9 @@ private[spark] object Utils extends Logging { /** Deserialize via nested stream using specific serializer */ def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)( - f: DeserializationStream => Unit) = { + f: DeserializationStream => Unit): Unit = { val isWrapper = ser.deserializeStream(new InputStream { - def read(): Int = is.read() - + override def read(): Int = is.read() override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) }) try { @@ -137,7 +135,7 @@ private[spark] object Utils extends Logging { /** * Get the ClassLoader which loaded Spark. 
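// A standalone sketch of the Utils.deserialize / JavaDeserializationStream change above:
// `resolveClass` returns the existential type `Class[_]`, which is now written out instead of
// being inferred from the `Class.forName` call. `Deser.fromBytes` is an illustrative name.
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}

object Deser {
  def fromBytes[T](bytes: Array[Byte], loader: ClassLoader): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
      // Resolve classes through the supplied loader rather than the default one.
      override def resolveClass(desc: ObjectStreamClass): Class[_] =
        Class.forName(desc.getName, false, loader)
    }
    try ois.readObject().asInstanceOf[T] finally ois.close()
  }
}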
*/ - def getSparkClassLoader = getClass.getClassLoader + def getSparkClassLoader: ClassLoader = getClass.getClassLoader /** * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that @@ -146,7 +144,7 @@ private[spark] object Utils extends Logging { * This should be used whenever passing a ClassLoader to Class.ForName or finding the currently * active loader when setting up ClassLoader delegation chains. */ - def getContextOrSparkClassLoader = + def getContextOrSparkClassLoader: ClassLoader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader) /** Determines whether the provided class is loadable in the current thread. */ @@ -155,12 +153,14 @@ private[spark] object Utils extends Logging { } /** Preferred alternative to Class.forName(className) */ - def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + def classForName(className: String): Class[_] = { + Class.forName(className, true, getContextOrSparkClassLoader) + } /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { @@ -1557,7 +1557,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ - def getFormattedClassName(obj: AnyRef) = { + def getFormattedClassName(obj: AnyRef): String = { obj.getClass.getSimpleName.replace("$", "") } @@ -1570,7 +1570,7 @@ private[spark] object Utils extends Logging { } /** Return an empty JSON object */ - def emptyJson = JObject(List[JField]()) + def emptyJson: JsonAST.JObject = JObject(List[JField]()) /** * Return a Hadoop FileSystem with the scheme encoded in the given path. @@ -1618,7 +1618,7 @@ private[spark] object Utils extends Logging { /** * Indicates whether Spark is currently running unit tests. */ - def isTesting = { + def isTesting: Boolean = { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index af1f64649f354..f79e8e0491ea1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -156,10 +156,10 @@ class BitSet(numBits: Int) extends Serializable { /** * Get an iterator over the set bits. 
*/ - def iterator = new Iterator[Int] { + def iterator: Iterator[Int] = new Iterator[Int] { var ind = nextSetBit(0) override def hasNext: Boolean = ind >= 0 - override def next() = { + override def next(): Int = { val tmp = ind ind = nextSetBit(ind + 1) tmp diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a0f5a602de12..9ff4744593d4d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -159,7 +159,7 @@ class ExternalAppendOnlyMap[K, V, C]( val batchSizes = new ArrayBuffer[Long] // Flush the disk writer's contents to disk, and update relevant variables - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() @@ -355,7 +355,7 @@ class ExternalAppendOnlyMap[K, V, C]( val pairs: ArrayBuffer[(K, C)]) extends Comparable[StreamBuffer] { - def isEmpty = pairs.length == 0 + def isEmpty: Boolean = pairs.length == 0 // Invalid if there are no more pairs in this stream def minKeyHash: Int = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index d69f2d9048055..3262e670c2030 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -283,7 +283,7 @@ private[spark] class ExternalSorter[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables. // The writer is closed at the end of this process, and cannot be reused. - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index b8de4ff9aa494..c52591b352340 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -109,7 +109,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = -1 var nextPair: (K, V) = computeNextPair() @@ -132,9 +132,9 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4e363b74f4bef..c80057f95e0b2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -85,7 +85,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( protected var _bitset = new BitSet(_capacity) - def getBitSet = _bitset + def getBitSet: BitSet = _bitset // Init of the array in constructor (instead of in declaration) to work around a Scala compiler // specialization bug that would generate two arrays (one for Object and one for specialized T). 
@@ -183,7 +183,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( /** Return the value at the specified position. */ def getValue(pos: Int): T = _data(pos) - def iterator = new Iterator[T] { + def iterator: Iterator[T] = new Iterator[T] { var pos = nextPos(0) override def hasNext: Boolean = pos != INVALID_POS override def next(): T = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index 2e1ef06cbc4e1..61e22642761f0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -46,7 +46,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = _keySet.size + override def size: Int = _keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -87,7 +87,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -103,9 +103,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index c5268c0fae0ef..bdbca00a00622 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -32,7 +32,7 @@ private[spark] object Utils { */ def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = { val ordering = new GuavaOrdering[T] { - override def compare(l: T, r: T) = ord.compare(l, r) + override def compare(l: T, r: T): Int = ord.compare(l, r) } collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator } diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 1d5467060623c..14b6ba4af489a 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -121,7 +121,7 @@ private[spark] object FileAppender extends Logging { val rollingSizeBytes = conf.get(SIZE_PROPERTY, STRATEGY_DEFAULT) val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT) - def createTimeBasedAppender() = { + def createTimeBasedAppender(): FileAppender = { val validatedParams: Option[(Long, String)] = rollingInterval match { case "daily" => logInfo(s"Rolling executor logs enabled for $file with daily rolling") @@ -149,7 +149,7 @@ private[spark] object FileAppender extends Logging { } } - def createSizeBasedAppender() = { + def createSizeBasedAppender(): FileAppender = { rollingSizeBytes match { case IntParam(bytes) => logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes") diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 76e7a2760bcd1..786b97ad7b9ec 100644 --- 
a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -105,7 +105,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals private val rng: Random = new XORShiftRandom - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (ub - lb <= 0.0) { @@ -131,7 +131,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals def cloneComplement(): BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, !complement) - override def clone = new BernoulliCellSampler[T](lb, ub, complement) + override def clone: BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, complement) } @@ -153,7 +153,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T private val rng: Random = RandomSampler.newDefaultRNG - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { @@ -167,7 +167,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T } } - override def clone = new BernoulliSampler[T](fraction) + override def clone: BernoulliSampler[T] = new BernoulliSampler[T](fraction) } @@ -209,7 +209,7 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] } } - override def clone = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) } @@ -228,15 +228,18 @@ class GapSamplingIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -244,21 +247,21 @@ class GapSamplingIterator[T: ClassTag]( override def next(): T = { val r = data.next() - advance + advance() r } private val lnq = math.log1p(-f) /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / lnq).toInt iterDrop(k) } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. 
@@ -279,15 +282,18 @@ class GapSamplingReplacementIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -300,7 +306,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( override def next(): T = { val r = v rep -= 1 - if (rep <= 0) advance + if (rep <= 0) advance() r } @@ -309,7 +315,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is * q is the probabililty of Poisson(0; f) */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / (-f)).toInt iterDrop(k) @@ -343,7 +349,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 2ae308dacf1ae..9e29bf9d61f17 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -311,7 +311,7 @@ private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: var acceptBound: Double = Double.NaN // upper bound for accepting item instantly var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist - def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN + def areBoundsEmpty: Boolean = acceptBound.isNaN || waitListBound.isNaN def merge(other: Option[AcceptanceResult]): Unit = { if (other.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 467b890fb4bb9..c4a7b4441c85c 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -83,7 +83,7 @@ private[spark] object XORShiftRandom { * @return Map of execution times for {@link java.util.Random java.util.Random} * and XORShift */ - def benchmark(numIters: Int) = { + def benchmark(numIters: Int): Map[String, Long] = { val seed = 1L val million = 1e6.toInt diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 328d59485a731..56f5dbe53fad4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,7 +44,13 @@ object MimaExcludes { // the maven-generated artifacts in 1.3. 
excludePackage("org.spark-project.jetty"), MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") ) case v if v.startsWith("1.3") => diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 0ff521706c71a..459a5035d4984 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,9 +137,9 @@ - + - + From 1afcf773d0cafdfd9bf106fdc5c429ed2ba3dd36 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 24 Mar 2015 01:12:11 -0700 Subject: [PATCH 002/171] [SPARK-6452] [SQL] Checks for missing attributes and unresolved operator for all types of operator In `CheckAnalysis`, `Filter` and `Aggregate` are checked in separate case clauses, thus never hit those clauses for unresolved operators and missing input attributes. This PR also removes the `prettyString` call when generating error message for missing input attributes. Because result of `prettyString` doesn't contain expression ID, and may give confusing messages like > resolved attributes a missing from a cc rxin [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5129) Author: Cheng Lian Closes #5129 from liancheng/spark-6452 and squashes the following commits: 52cdc69 [Cheng Lian] Addresses comments 029f9bd [Cheng Lian] Checks for missing attributes and unresolved operator for all types of operator --- .../sql/catalyst/analysis/CheckAnalysis.scala | 15 ++++++++++----- .../spark/sql/catalyst/plans/QueryPlan.scala | 7 +++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fb975ee5e7296..425e1e41cbf21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -63,7 +63,7 @@ class CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) => + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK case e: Attribute if !groupingExprs.contains(e) => @@ -85,13 +85,18 @@ class CheckAnalysis { cleaned.foreach(checkValidAggregateExpression) + case _ => // Fallbacks to the following checks + } + + operator match { case o if o.children.nonEmpty && o.missingInput.nonEmpty => - val missingAttributes = o.missingInput.map(_.prettyString).mkString(",") - val input = o.inputSet.map(_.prettyString).mkString(",") + val missingAttributes = o.missingInput.mkString(",") + val input = o.inputSet.mkString(",") - failAnalysis(s"resolved attributes $missingAttributes missing from $input") + failAnalysis( + s"resolved attribute(s) $missingAttributes missing from $input " + + s"in operator 
${operator.simpleString}") - // Catch all case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 400a6b2825c10..48191f31198f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -47,9 +47,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. + * + * Note that virtual columns should be excluded. Currently, we only support the grouping ID + * virtual column. */ - def missingInput: AttributeSet = (references -- inputSet) - .filter(_.name != VirtualColumn.groupingIdName) + def missingInput: AttributeSet = + (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c1dd5aa913ddc..359aec4a7b5ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -199,4 +199,22 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + LocalRelation( + AttributeReference("a", StringType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } } From 37fac1dcd2b0f2845d2952a417fcf85d40351f57 Mon Sep 17 00:00:00 2001 From: Brennon York Date: Tue, 24 Mar 2015 10:33:04 +0000 Subject: [PATCH 003/171] [SPARK-6477][Build]: Run MIMA tests before the Spark test suite This moves the MIMA checks to before the full Spark test suite such that, if new PR's fail the MIMA check, they will return much faster having not run the entire test suite. This is preferable to the current scenario where a user would have to wait until the entire test suite completes before realizing it failed on a MIMA check in which case, once the MIMA issues are fixed, the user would have to resubmit and rerun the full test suite again. 
Author: Brennon York Closes #5145 from brennonyork/SPARK-6477 and squashes the following commits: 12b0aee [Brennon York] updated to put the mima checks before the spark test suite --- dev/run-tests | 18 +++++++++--------- dev/run-tests-codes.sh | 6 +++--- dev/run-tests-jenkins | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index d6935a61c6d29..561d7fc9e7b1f 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -178,6 +178,15 @@ CURRENT_BLOCK=$BLOCK_BUILD fi } +echo "" +echo "=========================================================================" +echo "Detecting binary incompatibilities with MiMa" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_MIMA + +./dev/mima + echo "" echo "=========================================================================" echo "Running Spark unit tests" @@ -227,12 +236,3 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS ./python/run-tests - -echo "" -echo "=========================================================================" -echo "Detecting binary incompatibilities with MiMa" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_MIMA - -./dev/mima diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 1348e0609dda4..8ab6db6925d6e 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -22,6 +22,6 @@ readonly BLOCK_RAT=11 readonly BLOCK_SCALA_STYLE=12 readonly BLOCK_PYTHON_STYLE=13 readonly BLOCK_BUILD=14 -readonly BLOCK_SPARK_UNIT_TESTS=15 -readonly BLOCK_PYSPARK_UNIT_TESTS=16 -readonly BLOCK_MIMA=17 +readonly BLOCK_MIMA=15 +readonly BLOCK_SPARK_UNIT_TESTS=16 +readonly BLOCK_PYSPARK_UNIT_TESTS=17 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 5f4000e83925c..3a937b637e003 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -199,12 +199,12 @@ done failing_test="Python style tests" elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then failing_test="to build" + elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then + failing_test="MiMa tests" elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" - elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then - failing_test="MiMa tests" else failing_test="some tests" fi From c12312f8b16bb8f9355d5f9e786c5a608863eb01 Mon Sep 17 00:00:00 2001 From: Cong Yue Date: Tue, 24 Mar 2015 12:56:13 +0000 Subject: [PATCH 004/171] Update the command to use IPython notebook As for "notebook --pylab inline" is not supported any more, update the related documentation for this. Author: Cong Yue Closes #5111 from yuecong/patch-1 and squashes the following commits: 872df76 [Cong Yue] Update the command to use IPython notebook --- docs/programming-guide.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 5fe832b6fa100..f5b775da7930a 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -237,9 +237,13 @@ You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. 
the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +your notebook before you start to try Spark from the IPython notebook. + From b293afc42c370da0ff308c16ba4af81289ea6f89 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Mar 2015 13:48:33 +0000 Subject: [PATCH 005/171] [SPARK-6473] [core] Do not try to figure out Scala version if not needed... .... Author: Marcelo Vanzin Closes #5143 from vanzin/SPARK-6473 and squashes the following commits: a2e5e2d [Marcelo Vanzin] [SPARK-6473] [core] Do not try to figure out Scala version if not needed. --- .../org/apache/spark/launcher/AbstractCommandBuilder.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index dc90e9e987234..2da5f7278729e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -147,7 +147,6 @@ void addOptionString(List cmd, String options) { */ List buildClassPath(String appClassPath) throws IOException { String sparkHome = getSparkHome(); - String scala = getScalaVersion(); List cp = new ArrayList(); addToClassPath(cp, getenv("SPARK_CLASSPATH")); @@ -158,6 +157,7 @@ List buildClassPath(String appClassPath) throws IOException { boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES")); boolean isTesting = "1".equals(getenv("SPARK_TESTING")); if (prependClasses || isTesting) { + String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", "yarn", "launcher"); @@ -182,7 +182,7 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - String assembly = findAssembly(scala); + String assembly = findAssembly(); addToClassPath(cp, assembly); // When Hive support is needed, Datanucleus jars must be included on the classpath. 
Datanucleus @@ -330,7 +330,7 @@ String getenv(String key) { return firstNonEmpty(childEnv.get(key), System.getenv(key)); } - private String findAssembly(String scalaVersion) { + private String findAssembly() { String sparkHome = getSparkHome(); File libdir; if (new File(sparkHome, "RELEASE").isFile()) { @@ -338,7 +338,7 @@ private String findAssembly(String scalaVersion) { checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", libdir.getAbsolutePath()); } else { - libdir = new File(sparkHome, String.format("assembly/target/scala-%s", scalaVersion)); + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); } final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); From 85cf0636825d1997d64d0bdc04618f29b7222da1 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 24 Mar 2015 16:13:25 +0000 Subject: [PATCH 006/171] [SPARK-5559] [Streaming] [Test] Remove an opportunity for the flakiness we met when running FlumeStreamSuite When we run FlumeStreamSuite on Jenkins, we sometimes get an error like the following. sbt.ForkMain$ForkError: The code passed to eventually never returned normally. Attempted 52 times over 10.094849836 seconds. Last failure message: Error connecting to localhost/127.0.0.1:23456. at org.scalatest.concurrent.Eventually$class.tryTryAgain$1(Eventually.scala:420) at org.scalatest.concurrent.Eventually$class.eventually(Eventually.scala:438) at org.scalatest.concurrent.Eventually$.eventually(Eventually.scala:478) at org.scalatest.concurrent.Eventually$class.eventually(Eventually.scala:307) at org.scalatest.concurrent.Eventually$.eventually(Eventually.scala:478) at org.apache.spark.streaming.flume.FlumeStreamSuite.writeAndVerify(FlumeStreamSuite.scala:116) at org.apache.spark.streaming.flume.FlumeStreamSuite.org$apache$spark$streaming$flume$FlumeStreamSuite$$testFlumeStream(FlumeStreamSuite.scala:74) at org.apache.spark.streaming.flume.FlumeStreamSuite$$anonfun$3.apply$mcV$sp(FlumeStreamSuite.scala:66) at org.apache.spark.streaming.flume.FlumeStreamSuite$$anonfun$3.apply(FlumeStreamSuite.scala:66) at org.apache.spark.streaming.flume.FlumeStreamSuite$$anonfun$3.apply(FlumeStreamSuite.scala:66) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.scalatest.Suite$class.withFixture(Suite.scala:1122) at org.scalatest.FunSuite.withFixture(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) This error is caused by the check-then-act logic used when finding a free port. /** Find a free port */ private def findFreePort(): Int = { Utils.startServiceOnPort(23456, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) }, conf)._2 } Removing the check-then-act is not easy, but we can reduce the chance of hitting the error by choosing a random value for the initial port instead of 23456.
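To illustrate the idea (a minimal sketch only; the `PortProbe` object and its parameters are made up here and this is not Spark's actual `Utils.startServiceOnPort`), starting the probe from a random candidate makes it unlikely that two concurrently running suites contend for the same port:

```scala
import java.net.ServerSocket
import scala.util.Random

object PortProbe {
  /** Probe for a free port, starting from a random candidate in [1024, 65536). */
  def findFreePort(maxRetries: Int = 100): Int = {
    val range = 65536 - 1024
    val start = 1024 + Random.nextInt(range)
    (0 to maxRetries).iterator
      .map(offset => 1024 + (start - 1024 + offset) % range)
      .find { candidate =>
        try {
          new ServerSocket(candidate).close() // still check-then-act, but rarely contended
          true
        } catch {
          case _: java.io.IOException => false
        }
      }
      .getOrElse(sys.error(s"Could not find a free port after $maxRetries retries"))
  }
}
```

The check-then-act window is still there; randomizing the starting point only makes collisions rare instead of likely when every suite begins probing from the same default port.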
Author: Kousuke Saruta Closes #4337 from sarutak/SPARK-5559 and squashes the following commits: 16f109f [Kousuke Saruta] Added `require` to Utils#startServiceOnPort c39d8b6 [Kousuke Saruta] Merge branch 'SPARK-5559' of github.com:sarutak/spark into SPARK-5559 1610ba2 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5559 33357e3 [Kousuke Saruta] Changed "findFreePort" method in MQTTStreamSuite and FlumeStreamSuite so that it can choose valid random port a9029fe [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5559 9489ef9 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5559 8212e42 [Kousuke Saruta] Modified default port used in FlumeStreamSuite from 23456 to random value --- core/src/main/scala/org/apache/spark/util/Utils.scala | 4 ++++ .../org/apache/spark/streaming/flume/FlumeStreamSuite.scala | 5 +++-- .../org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d9a671687aad0..0b5a914e7dbbf 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1876,6 +1876,10 @@ private[spark] object Utils extends Logging { startService: Int => (T, Int), conf: SparkConf, serviceName: String = ""): (T, Int) = { + + require(startPort == 0 || (1024 <= startPort && startPort < 65536), + "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.") + val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" val maxRetries = portMaxRetries(conf) for (offset <- 0 to maxRetries) { diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 322de7bf2fed8..51d273af8da84 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -28,6 +28,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline @@ -40,7 +41,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import org.apache.spark.util.Utils class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { @@ -76,7 +76,8 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L /** Find a free port */ private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala 
b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 0f3298af6234a..24d78ecb3a97d 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -25,6 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.activemq.broker.{TransportConnector, BrokerService} +import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence @@ -113,7 +114,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { } private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) From 08d452801195cc6cf0697a594e98cd4778f358ee Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Tue, 24 Mar 2015 16:33:38 +0000 Subject: [PATCH 007/171] [ML][docs][minor] Define LabeledDocument/Document classes in CV example To make the Cross-Validation example code snippet easier to copy/paste, it needs to define LabeledDocument/Document itself, since they are defined in a previous example. Author: Peter Rudenko Closes #5135 from petro-rudenko/patch-3 and squashes the following commits: 5190c75 [Peter Rudenko] Fix primitive types for java examples. 1d35383 [Peter Rudenko] [SQL][docs][minor] Define LabeledDocument/Document classes in CV example --- docs/ml-guide.md | 51 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index da6aef7f14c4c..c08c76d226713 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -408,31 +408,31 @@ import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. public class Document implements Serializable { - private Long id; + private long id; private String text; - public Document(Long id, String text) { + public Document(long id, String text) { this.id = id; this.text = text; } - public Long getId() { return this.id; } - public void setId(Long id) { this.id = id; } + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } public String getText() { return this.text; } public void setText(String text) { this.text = text; } } public class LabeledDocument extends Document implements Serializable { - private Double label; + private double label; - public LabeledDocument(Long id, String text, Double label) { + public LabeledDocument(long id, String text, double label) { super(id, text); this.label = label; } - public Double getLabel() { return this.label; } - public void setLabel(Double label) { this.label = label; } + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } } // Set up contexts. @@ -565,6 +565,11 @@ import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from case classes.
+case class LabeledDocument(id: Long, text: String, label: Double) +case class Document(id: Long, text: String) + val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) @@ -655,6 +660,36 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); From a1d1529daebee30b76b954d16a30849407f795d1 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 24 Mar 2015 10:11:27 -0700 Subject: [PATCH 008/171] [SPARK-6475][SQL] recognize array types when infer data types from JavaBeans Right now, if there is an array field in a JavaBean, the user would see an exception in `createDataFrame`. liancheng Author: Xiangrui Meng Closes #5146 from mengxr/SPARK-6475 and squashes the following commits: 51e87e5 [Xiangrui Meng] validate schemas 4f2df5e [Xiangrui Meng] recognize array types when infer data types from JavaBeans --- .../org/apache/spark/sql/SQLContext.scala | 80 ++++++++++++------- .../apache/spark/sql/JavaDataFrameSuite.java | 41 +++++++++- 2 files changed, 89 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dc9912b52dcab..e59cf9b9e037b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { + val (dataType, _) = inferDataType(beanClass) + dataType.asInstanceOf[StructType].fields.map { f => + AttributeReference(f.name, f.dataType, f.nullable)() + } + } + + /** + * Infers the corresponding SQL data type of a Java class. + * @param clazz Java class + * @return (SQL data type, nullable) + */ + private def inferDataType(clazz: Class[_]): (DataType, Boolean) = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. - val beanInfo = Introspector.getBeanInfo(beanClass) - - // Note: The ordering of elements may differ from when the schema is inferred in Scala. // This is because beanInfo.getPropertyDescriptors gives no guarantees about - // element ordering.
- val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - fields.map { property => - val (dataType, nullable) = property.getPropertyType match { - case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) - case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) - } - AttributeReference(property.getName, dataType, nullable)() + clazz match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + + case c: Class[_] if c.isArray => + val (dataType, nullable) = inferDataType(c.getComponentType) + (ArrayType(dataType, nullable), true) + + case _ => + val beanInfo = Introspector.getBeanInfo(clazz) + val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val fields = properties.map { property => + val (dataType, 
nullable) = inferDataType(property.getPropertyType) + new StructField(property.getName, dataType, nullable) + } + (new StructType(fields), true) } } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 2d586f784ac5a..1ff2d5a190521 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,29 +17,39 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.test.TestSQLContext$; -import static org.apache.spark.sql.functions.*; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { + private transient JavaSparkContext jsc; private transient SQLContext context; @Before public void setUp() { // Trigger static initializer of TestData TestData$.MODULE$.testData(); + jsc = new JavaSparkContext(TestSQLContext.sparkContext()); context = TestSQLContext$.MODULE$; } @After public void tearDown() { + jsc = null; context = null; } @@ -90,4 +100,33 @@ public void testShow() { df.show(); df.show(1000); } + + public static class Bean implements Serializable { + private double a = 0.0; + private Integer[] b = new Integer[]{0, 1}; + + public double getA() { + return a; + } + + public Integer[] getB() { + return b; + } + } + + @Test + public void testCreateDataFrameFromJavaBeans() { + Bean bean = new Bean(); + JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); + DataFrame df = context.createDataFrame(rdd, Bean.class); + StructType schema = df.schema(); + Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), + schema.apply("a")); + Assert.assertEquals( + new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), + schema.apply("b")); + Row first = df.select("a", "b").first(); + Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); + Assert.assertArrayEquals(bean.getB(), first.getAs(1)); + } } From 6bdddb6f6ffbd1bee4c45880904e9561b18764a7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 24 Mar 2015 12:08:19 -0700 Subject: [PATCH 009/171] [SPARK-6361][SQL] support adding a column with metadata in DF This is used by ML pipelines to embed ML attributes in columns created by ML transformers/estimators. 
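As a usage sketch (assuming `df` is some already constructed DataFrame with a `features` column; the column name and metadata key below are made up for illustration), the new overload lets a caller attach metadata while aliasing a column and read it back from the schema:

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.MetadataBuilder

// Build some metadata and attach it to a column via the new as(alias, metadata) overload.
val meta = new MetadataBuilder().putString("mlAttr", "numeric").build()
val withMeta = df.select(col("features").as("features", meta))

// Downstream stages can recover the metadata from the schema.
assert(withMeta.schema("features").metadata == meta)
```

This mirrors what the new `ColumnExpressionSuite` test below exercises with `MetadataBuilder`.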
marmbrus Author: Xiangrui Meng Closes #5151 from mengxr/SPARK-6361 and squashes the following commits: bb30de3 [Xiangrui Meng] support adding a column with metadata in DF --- .../expressions/namedExpressions.scala | 20 ++++++++++++------- .../scala/org/apache/spark/sql/Column.scala | 13 ++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 15 +++++++++++--- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 17f7f9fe51376..3dd7d38847b44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -95,9 +95,12 @@ abstract class Attribute extends NamedExpression { * @param name the name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. + * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ -case class Alias(child: Expression, name: String) - (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) +case class Alias(child: Expression, name: String)( + val exprId: ExprId = NamedExpression.newExprId, + val qualifiers: Seq[String] = Nil, + val explicitMetadata: Option[Metadata] = None) extends NamedExpression with trees.UnaryNode[Expression] { override type EvaluatedType = Any @@ -107,9 +110,11 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable override def metadata: Metadata = { - child match { - case named: NamedExpression => named.metadata - case _ => Metadata.empty + explicitMetadata.getOrElse { + child match { + case named: NamedExpression => named.metadata + case _ => Metadata.empty + } } } @@ -123,11 +128,12 @@ case class Alias(child: Expression, name: String) override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" - override protected final def otherCopyArgs = exprId :: qualifiers :: Nil + override protected final def otherCopyArgs = exprId :: qualifiers :: explicitMetadata :: Nil override def equals(other: Any): Boolean = other match { case a: Alias => - name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers + name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers && + explicitMetadata == a.explicitMetadata case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ec7d15f5bc4e7..2ae47e07d45ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -594,6 +594,19 @@ class Column(protected[sql] val expr: Expression) { */ def as(alias: Symbol): Column = Alias(expr, alias.name)() + /** + * Gives the column an alias with metadata. + * {{{ + * val metadata: Metadata = ... + * df.select($"colA".as("colB", metadata)) + * }}} + * + * @group expr_ops + */ + def as(alias: String, metadata: Metadata): Column = { + Alias(expr, alias)(explicitMetadata = Some(metadata)) + } + /** * Casts the column to a different data type. 
* {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index a53ae97d6243a..bc8fae100db6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.NamedExpression -import org.apache.spark.sql.catalyst.plans.logical.{Project, NoRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { @@ -322,4 +320,15 @@ class ColumnExpressionSuite extends QueryTest { assert('key.desc == 'key.desc) assert('key.desc != 'key.asc) } + + test("alias with metadata") { + val metadata = new MetadataBuilder() + .putString("originName", "value") + .build() + val schema = testData + .select($"*", col("value").as("abc", metadata)) + .schema + assert(schema("value").metadata === Metadata.empty) + assert(schema("abc").metadata === metadata) + } } From 32efadd0500f10bddf2ae8456c9e719ec52940f1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Mar 2015 12:09:02 -0700 Subject: [PATCH 010/171] [SPARK-6459][SQL] Warn when constructing trivially true equals predicate For example, one might expect the following code to work, but it does not. Now you will at least get a warning with a suggestion to use aliases. ```scala val df = sqlContext.load(path, "parquet") val txns = df.groupBy("cust_id").agg($"cust_id", countDistinct($"day_num").as("txns")) val spend = df.groupBy("cust_id").agg($"cust_id", sum($"extended_price").as("spend")) val rmJoin = txns.join(spend, txns("cust_id") === spend("cust_id"), "inner") ``` Author: Michael Armbrust Closes #5163 from marmbrus/selfJoinError and squashes the following commits: 16c1f0b [Michael Armbrust] fix visibility 1b57e8d [Michael Armbrust] Warn when constructing trivially true equals predicate --- .../main/scala/org/apache/spark/sql/Column.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 2ae47e07d45ec..3cd7adf8cab5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.language.implicitConversions import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField} @@ -46,7 +47,7 @@ private[sql] object Column { * @groupname Ungrouped Support functions for DataFrames. 
*/ @Experimental -class Column(protected[sql] val expr: Expression) { +class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) @@ -109,7 +110,15 @@ class Column(protected[sql] val expr: Expression) { * * @group expr_ops */ - def === (other: Any): Column = EqualTo(expr, lit(other).expr) + def === (other: Any): Column = { + val right = lit(other).expr + if (this.expr == right) { + logWarning( + s"Constructing trivially true equals predicate, '${this.expr} = $right'. " + + "Perhaps you need to use aliases.") + } + EqualTo(expr, right) + } /** * Equality test. From 26c6ce3d2947df5a294b1ad4a22fae5d31d06c19 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Mar 2015 12:10:30 -0700 Subject: [PATCH 011/171] [SPARK-6437][SQL] Use completion iterator to close external sorter Otherwise we will leak files when spilling occurs. Author: Michael Armbrust Closes #5161 from marmbrus/cleanupAfterSort and squashes the following commits: cb13d3c [Michael Armbrust] hint to inferencer cdebdf5 [Michael Armbrust] Use completion iterator to close external sorter --- .../org/apache/spark/sql/execution/basicOperators.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 20c9bc3e75542..1f5251a20376f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.MutablePair +import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.util.collection.ExternalSorter /** @@ -194,7 +194,9 @@ case class ExternalSort( val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r, null))) - sorter.iterator.map(_._1) + val baseIterator = sorter.iterator.map(_._1) + // TODO(marmbrus): The complex type signature below thwarts inference for no reason. + CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } From 3fa3d121dfec60f9768d3859e8450ee482b2d4e8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Mar 2015 12:28:01 -0700 Subject: [PATCH 012/171] [SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes Due to a recent change that made `StructType` a `Seq` we started inadvertently turning `StructType`s into generic `Traversable` when attempting nested tree transformations. In this PR we explicitly avoid descending into `DataType`s to avoid this bug. 
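To see why the guard is needed (an illustrative snippet, not the actual `TreeNode` code; the pattern match below just stands in for the generic collection case in `transformDown`/`transformUp`), note that a `StructType` is itself a `Seq[StructField]`, so a catch-all `Traversable` case matches it and rebuilds it as a plain sequence:

```scala
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))

// Constructor arguments are handled generically in TreeNode, so the schema arrives as Any.
val arg: Any = schema
val remapped = arg match {
  // A generic collection case, like the one in TreeNode's transform methods, matches a
  // StructType because it is a Seq[StructField]...
  case seq: Traversable[_] => seq.map { case other => other }
  case other => other
}

// ...and the mapped result is a plain collection rather than a StructType, which is the
// corruption that the added `case d: DataType => d` clauses prevent.
assert(!remapped.isInstanceOf[StructType])
```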
Author: Michael Armbrust Closes #5157 from marmbrus/udfFix and squashes the following commits: 26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes --- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 ++ .../spark/sql/catalyst/trees/TreeNode.scala | 20 ++++++++++++++++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 6 ++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 48191f31198f3..bd9291e9ba5d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) case other => other @@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f84ffe4e176cc..0ae9f6b2965d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.types.DataType /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) @@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) @@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + val defaultCtor = + getClass.getConstructors + .find(_.getParameterTypes.size != 0) + .headOption + .getOrElse(sys.error(s"No valid constructor for $nodeName")) + try { CurrentOrigin.withOrigin(origin) { // Skip no-arg constructors that are just there for kryo. 
- val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head if (otherCopyArgs.isEmpty) { defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] } else { @@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " - + s"Exception message: ${e.getMessage}.") + this, + s""" + |Failed to copy node. + |Is otherCopyArgs specified correctly for $nodeName. + |Exception message: ${e.getMessage} + |ctor: $defaultCtor? + |args: ${newArgs.mkString(", ")} + """.stripMargin) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index be105c6e83594..d615542ab50a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -50,4 +50,10 @@ class UDFSuite extends QueryTest { .select($"ret.f1").head().getString(0) assert(result === "test") } + + test("udf that is transformed") { + udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + // 1 + 1 is constant folded causing a transformation. + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } } From 046c1e2aa459147bf592371bb9fb7a65edb182e7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Mar 2015 13:22:46 -0700 Subject: [PATCH 013/171] [SPARK-6375][SQL] Fix formatting of error messages. Author: Michael Armbrust Closes #5155 from marmbrus/errorMessages and squashes the following commits: b898188 [Michael Armbrust] Fix formatting of error messages. --- .../apache/spark/sql/AnalysisException.scala | 6 +++++ .../sql/catalyst/analysis/Analyzer.scala | 5 ++-- .../spark/sql/catalyst/analysis/package.scala | 8 ++++++ .../sql/catalyst/analysis/unresolved.scala | 4 +++ .../expressions/namedExpressions.scala | 7 ++++++ .../catalyst/plans/logical/LogicalPlan.scala | 3 ++- .../spark/sql/hive/ErrorPositionSuite.scala | 25 +++++++++++++++++-- 7 files changed, 53 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 15add84878ecf..34fedead44db3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -30,6 +30,12 @@ class AnalysisException protected[sql] ( val startPosition: Option[Int] = None) extends Exception with Serializable { + def withPosition(line: Option[Int], startPosition: Option[Int]) = { + val newException = new AnalysisException(message, line, startPosition) + newException.setStackTrace(getStackTrace) + newException + } + override def getMessage: String = { val lineAnnotation = line.map(l => s" line $l").getOrElse("") val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 92d3db077c5e1..c93af79795bc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -175,7 +175,7 @@ class Analyzer(catalog: Catalog, catalog.lookupRelation(u.tableIdentifier, u.alias) } 
catch { case _: NoSuchTableException => - u.failAnalysis(s"no such table ${u.tableIdentifier}") + u.failAnalysis(s"no such table ${u.tableName}") } } @@ -275,7 +275,8 @@ class Analyzer(catalog: Catalog, q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolveChildren(name, resolver).getOrElse(u) + val result = + withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index e95f19e69ed43..a7d3a8ee7deb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -42,4 +42,12 @@ package object analysis { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } } + + /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ + def withPosition[A](t: TreeNode[_])(f: => A) = { + try f catch { + case a: AnalysisException => + throw a.withPosition(t.origin.line, t.origin.startPosition) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a7cd4124e56f3..ad5172c0349eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -36,6 +36,10 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str case class UnresolvedRelation( tableIdentifier: Seq[String], alias: Option[String] = None) extends LeafNode { + + /** Returns a `.` separated name for this relation. */ + def tableName = tableIdentifier.mkString(".") + override def output = Nil override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3dd7d38847b44..08361d043b6ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -41,6 +41,13 @@ abstract class NamedExpression extends Expression { def name: String def exprId: ExprId + /** + * Returns a dot separated fully qualified name for this attribute. Given that there can be + * multiple qualifiers, it is possible that there are other possible way to refer to this + * attribute. + */ + def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".") + /** * All possible qualifiers for the expression. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8c4f09b58a4f2..0f8b144ccc113 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -208,8 +208,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // More than one match. case ambiguousReferences => + val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") throw new AnalysisException( - s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") + s"Reference '$name' is ambiguous, could be: $referenceNames.") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index f04437c595bf6..968557c9c4686 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -19,12 +19,29 @@ package org.apache.spark.sql.hive import java.io.{OutputStream, PrintStream} +import scala.util.Try + +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.{AnalysisException, QueryTest} -import scala.util.Try -class ErrorPositionSuite extends QueryTest { +class ErrorPositionSuite extends QueryTest with BeforeAndAfter { + + before { + Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") + } + + positionTest("ambiguous attribute reference 1", + "SELECT a from dupAttributes", "a") + + positionTest("ambiguous attribute reference 2", + "SELECT a, b from dupAttributes", "a") + + positionTest("ambiguous attribute reference 3", + "SELECT b, a from dupAttributes", "a") positionTest("unresolved attribute 1", "SELECT x FROM src", "x") @@ -127,6 +144,10 @@ class ErrorPositionSuite extends QueryTest { val error = intercept[AnalysisException] { quietly(sql(query)) } + + assert(!error.getMessage.contains("Seq(")) + assert(!error.getMessage.contains("List(")) + val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect { case (l, i) if l.contains(token) => (l, i + 1) }.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query")) From cbeaf9ebab31a0bcbca884d4db7a791fd9edbff3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Mar 2015 14:08:20 -0700 Subject: [PATCH 014/171] [SPARK-6376][SQL] Avoid eliminating subqueries until optimization Previously it was okay to throw away subqueries after analysis, as we would never try to use that tree for resolution again. However, with eager analysis in `DataFrame`s this can cause errors for queries such as: ```scala val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count() ``` As a result, in this PR we defer the elimination of subqueries until the optimization phase. 
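As a toy sketch of the ordering this patch establishes (the `Plan`/`Aliased` types below are hypothetical, not Spark's classes): alias-introducing subquery nodes stay in the analyzed plan so qualified references such as `x.str` can still be resolved, and only the optimizer strips them before execution.

```scala
sealed trait Plan
case class Table(name: String) extends Plan
case class Aliased(alias: String, child: Plan) extends Plan

object RemoveAliases {
  // Safe to run only once resolution no longer needs the alias.
  def apply(plan: Plan): Plan = plan match {
    case Aliased(_, child) => apply(child)
    case other             => other
  }
}

object AliasOrderingDemo {
  def main(args: Array[String]): Unit = {
    val analyzed  = Aliased("x", Table("t")) // the analyzer now keeps this node
    val optimized = RemoveAliases(analyzed)  // the optimizer removes it
    println(optimized)                       // Table(t)
  }
}
```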
Author: Michael Armbrust Closes #5160 from marmbrus/subqueriesInDfs and squashes the following commits: a9bb262 [Michael Armbrust] Update Optimizer.scala 27d25bf [Michael Armbrust] fix hive tests 9137e03 [Michael Armbrust] add type 81cd597 [Michael Armbrust] Avoid eliminating subqueries until optimization --- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 +--- .../apache/spark/sql/catalyst/dsl/package.scala | 4 ++-- .../spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++++ .../sql/catalyst/plans/logical/LogicalPlan.scala | 16 ++++++++++------ .../sql/catalyst/analysis/AnalysisSuite.scala | 8 ++++++-- .../org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ .../scala/org/apache/spark/sql/JoinSuite.scala | 4 ++-- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- 9 files changed, 34 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c93af79795bc7..13d2ae4c6fcb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -64,9 +64,7 @@ class Analyzer(catalog: Catalog, UnresolvedHavingClauseAttributes :: TrimGroupingAliases :: typeCoercionRules ++ - extendedResolutionRules : _*), - Batch("Remove SubQueries", fixedPoint, - EliminateSubQueries) + extendedResolutionRules : _*) ) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 51a09ac0e1249..7f5f617812919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -289,7 +289,7 @@ package object dsl { InsertIntoTable( analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) - def analyze = analysis.SimpleAnalyzer(logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) } object plans { // scalastyle:ignore diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1a75fcf3545bd..74edaacc4f609 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -32,6 +33,9 @@ abstract class Optimizer extends 
RuleExecutor[LogicalPlan] object DefaultOptimizer extends Optimizer { val batches = + // SubQueries are only needed for analysis and can be removed before execution. + Batch("Remove SubQueries", FixedPoint(100), + EliminateSubQueries) :: Batch("Combine Limits", FixedPoint(100), CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0f8b144ccc113..b01a61d7bf8d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode @@ -73,12 +73,16 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * can do better should override this function. */ def sameResult(plan: LogicalPlan): Boolean = { - plan.getClass == this.getClass && - plan.children.size == children.size && { - logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") - cleanArgs == plan.cleanArgs + val cleanLeft = EliminateSubQueries(this) + val cleanRight = EliminateSubQueries(plan) + + cleanLeft.getClass == cleanRight.getClass && + cleanLeft.children.size == cleanRight.children.size && { + logDebug( + s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") + cleanRight.cleanArgs == cleanLeft.cleanArgs } && - (plan.children, children).zipped.forall(_ sameResult _) + (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) } /** Args that have cleaned such that differences in expression id should not affect equality */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 359aec4a7b5ab..756cd36f05c8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -32,9 +32,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseInsensitiveCatalog = new SimpleCatalog(false) val caseSensitiveAnalyzer = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } val caseInsensitiveAnalyzer = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } val checkAnalysis = new CheckAnalysis diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ff441ef26f9c0..c30ed694a62f0 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -108,6 +108,13 @@ class DataFrameSuite extends QueryTest { ) } + test("self join with aliases") { + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + test("explode") { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index dd0948ad824be..e4dee87849fd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -34,7 +34,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -109,7 +109,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index ff2e6ea9ea51d..e5ad0bf552073 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -579,7 +579,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row(3) :: Row(4) :: Nil ) - table("test_parquet_ctas").queryExecution.analyzed match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation2) => // OK case _ => fail( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index d891c4e8903d9..8a31bd03092d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -292,7 +292,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { Seq(Row(1, "str1")) ) - table("test_parquet_ctas").queryExecution.analyzed match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation2) => // OK case _ => fail( From a8f51b82968147abebbe61b8b68b066d21a0c6e6 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Mar 2015 14:10:56 -0700 Subject: [PATCH 015/171] [SPARK-6458][SQL] Better error messages for invalid data sources Avoid unclear match errors and use `AnalysisException`. 
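A minimal, self-contained sketch of the change in spirit (toy traits, not Spark's sources API): instead of letting a non-exhaustive match surface as an opaque `scala.MatchError`, a catch-all case raises a descriptive exception that names the offending class.

```scala
class AnalysisError(message: String) extends Exception(message)

trait RelationProvider
class CsvProvider extends RelationProvider
class NotAProvider // lacks the required trait

object ResolveProvider {
  def describe(instance: Any, className: String): String = instance match {
    case _: RelationProvider => s"$className resolved"
    // Without this case, a bare MatchError with no context would surface here.
    case _ => throw new AnalysisError(s"$className is not a RelationProvider.")
  }
}

object ProviderDemo {
  def main(args: Array[String]): Unit = {
    println(ResolveProvider.describe(new CsvProvider, "CsvProvider"))
    try ResolveProvider.describe(new NotAProvider, "NotAProvider")
    catch { case e: AnalysisError => println(e.getMessage) }
  }
}
```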
Author: Michael Armbrust Closes #5158 from marmbrus/dataSourceError and squashes the following commits: af9f82a [Michael Armbrust] Yins comment 90c6ba4 [Michael Armbrust] Better error messages for invalid data sources --- .../scala/org/apache/spark/sql/sources/ddl.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index d2e807d3a69b6..eb46b46ca5bf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} +import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -204,19 +204,25 @@ private[sql] object ResolvedDataSource { provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) + def className = clazz.getCanonicalName val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) case dataSource: org.apache.spark.sql.sources.RelationProvider => - sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } case None => clazz.newInstance() match { case dataSource: RelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } } new ResolvedDataSource(clazz, relation) From 7215aa745590a3eec9c1ff35d28194235a550db7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 24 Mar 2015 14:38:20 -0700 Subject: [PATCH 016/171] [SPARK-6209] Clean up connections in ExecutorClassLoader after failing to load classes (master branch PR) ExecutorClassLoader does not ensure proper cleanup of network connections that it opens. If it fails to load a class, it may leak partially-consumed InputStreams that are connected to the REPL's HTTP class server, causing that server to exhaust its thread pool, which can cause the entire job to hang. See [SPARK-6209](https://issues.apache.org/jira/browse/SPARK-6209) for more details, including a bug reproduction. This patch fixes this issue by ensuring proper cleanup of these resources. It also adds logging for unexpected error cases. This PR is an extended version of #4935 and adds a regression test. 
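A condensed sketch of the cleanup pattern the patch applies, using only standard JDK APIs; the method and variable names here are illustrative, not the patch's exact signatures. On a non-200 response it closes the error stream so the keep-alive connection can be reused, and it disconnects on unexpected failures so the socket is not leaked.

```scala
import java.io.InputStream
import java.net.{HttpURLConnection, URL}
import scala.util.control.NonFatal

object HttpClassFetch {
  def openClassStream(url: URL): InputStream = {
    val connection = url.openConnection().asInstanceOf[HttpURLConnection]
    connection.connect()
    try {
      if (connection.getResponseCode != 200) {
        // Closing (rather than abandoning) the error stream keeps the
        // underlying connection eligible for re-use by the pool.
        Option(connection.getErrorStream).foreach(_.close())
        throw new ClassNotFoundException(s"Class file not found at URL $url")
      } else {
        connection.getInputStream
      }
    } catch {
      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
        connection.disconnect() // give up on the connection entirely
        throw e
    }
  }
}
```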
Author: Josh Rosen Closes #4944 from JoshRosen/executorclassloader-leak-master-branch and squashes the following commits: e0e3c25 [Josh Rosen] Wrap try block around getReponseCode; re-enable keep-alive by closing error stream 961c284 [Josh Rosen] Roll back changes that were added to get the regression test to fail 7ee2261 [Josh Rosen] Add a failing regression test e2d70a3 [Josh Rosen] Properly clean up after errors in ExecutorClassLoader --- repl/pom.xml | 5 ++ .../spark/repl/ExecutorClassLoader.scala | 85 +++++++++++++++---- .../spark/repl/ExecutorClassLoaderSuite.scala | 70 ++++++++++++++- 3 files changed, 140 insertions(+), 20 deletions(-) diff --git a/repl/pom.xml b/repl/pom.xml index edfa1c7f2c29c..03053b4c3b287 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -84,6 +84,11 @@ scalacheck_${scala.binary.version} test + + org.mockito + mockito-all + test + diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 9805609120005..004941d5f50ae 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,9 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException} -import java.net.{URI, URL, URLEncoder} -import java.util.concurrent.{Executors, ExecutorService} +import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.net.{HttpURLConnection, URI, URL, URLEncoder} + +import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val parentLoader = new ParentClassLoader(parent) + // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes + private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 + // Hadoop FileSystem object for our URI, if it isn't using HTTP var fileSystem: FileSystem = { if (Set("http", "https", "ftp").contains(uri.getScheme)) { @@ -71,30 +75,66 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { + val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + newuri.toURL + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)) + } + val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), + SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] + // Set the connection timeouts (for testing purposes) + if (httpUrlConnectionTimeoutMillis != -1) { + connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) + connection.setReadTimeout(httpUrlConnectionTimeoutMillis) + } + connection.connect() + try { + if (connection.getResponseCode != 200) { + // Close the error stream so that the connection is eligible for re-use + try { + connection.getErrorStream.close() + } catch { + case ioe: IOException => + logError("Exception while closing error stream", ioe) + } + throw new ClassNotFoundException(s"Class file not found at URL $url") + } else { + connection.getInputStream + } + } catch { + case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => + connection.disconnect() + throw e + } + } + + private def 
getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + val path = new Path(directory, pathInDirectory) + if (fileSystem.exists(path)) { + fileSystem.open(path) + } else { + throw new ClassNotFoundException(s"Class file not found at path $path") + } + } + def findClassLocally(name: String): Option[Class[_]] = { + val pathInDirectory = name.replace('.', '/') + ".class" + var inputStream: InputStream = null try { - val pathInDirectory = name.replace('.', '/') + ".class" - val inputStream = { + inputStream = { if (fileSystem != null) { - fileSystem.open(new Path(directory, pathInDirectory)) + getClassFileInputStreamFromFileSystem(pathInDirectory) } else { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - - Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager) - .getInputStream + getClassFileInputStreamFromHttpServer(pathInDirectory) } } val bytes = readAndTransformClass(name, inputStream) - inputStream.close() Some(defineClass(name, bytes, 0, bytes.length)) } catch { - case e: FileNotFoundException => + case e: ClassNotFoundException => // We did not find the class logDebug(s"Did not load class $name from REPL class server at $uri", e) None @@ -102,6 +142,15 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Something bad happened while checking if the class exists logError(s"Failed to check existence of class $name on REPL class server at $uri", e) None + } finally { + if (inputStream != null) { + try { + inputStream.close() + } catch { + case e: Exception => + logError("Exception while closing inputStream", e) + } + } } } diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 6a79e76a34db8..c709cde740748 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -20,13 +20,25 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.scalatest.concurrent.Interruptor +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito._ -import org.apache.spark.{SparkConf, TestUtils} +import org.apache.spark._ import org.apache.spark.util.Utils -class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { +class ExecutorClassLoaderSuite + extends FunSuite + with BeforeAndAfterAll + with MockitoSugar + with Logging { val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") @@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { var tempDir2: File = _ var url1: String = _ var urls2: Array[URL] = _ + var classServer: HttpServer = _ override def beforeAll() { super.beforeAll() @@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def afterAll() { super.afterAll() + if (classServer != 
null) { + classServer.stop() + } Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) } test("child first") { @@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { } } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { + // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class + // from the driver's class server would leak a HTTP connection, causing the class server's + // thread / connection pool to be exhausted. + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + classServer = new HttpServer(conf, tempDir1, securityManager) + classServer.start() + // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this + val mockEnv = mock[SparkEnv] + when(mockEnv.securityManager).thenReturn(securityManager) + SparkEnv.set(mockEnv) + // Create an ExecutorClassLoader that's configured to load classes from the HTTP server + val parentLoader = new URLClassLoader(Array.empty, null) + val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + classLoader.httpUrlConnectionTimeoutMillis = 500 + // Check that this class loader can actually load classes that exist + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + // Try to perform a full GC now, since GC during the test might mask resource leaks + System.gc() + // When the original bug occurs, the test thread becomes blocked in a classloading call + // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to + // shut down the HTTP server when the test times out + val interruptor: Interruptor = new Interruptor { + override def apply(thread: Thread): Unit = { + classServer.stop() + classServer = null + thread.interrupt() + } + } + def tryAndFailToLoadABunchOfClasses(): Unit = { + // The number of trials here should be much larger than Jetty's thread / connection limit + // in order to expose thread or connection leaks + for (i <- 1 to 1000) { + if (Thread.currentThread().isInterrupted) { + throw new InterruptedException() + } + // Incorporate the iteration number into the class name in order to avoid any response + // caching that might be added in the future + intercept[ClassNotFoundException] { + classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance() + } + } + } + failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) + } + } From 73348012d4ce6c9db85dfb48d51026efe5051c73 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Mar 2015 16:03:55 -0700 Subject: [PATCH 017/171] [SPARK-6428][SQL] Added explicit types for all public methods in catalyst I think after this PR, we can finally turn the rule on. There are still some smaller ones that need to be fixed, but those are easier. Author: Reynold Xin Closes #5162 from rxin/catalyst-explicit-types and squashes the following commits: e7eac03 [Reynold Xin] [SPARK-6428][SQL] Added explicit types for all public methods in catalyst. 
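A two-line illustration of the convention being enforced (a toy trait, not taken from the diff below): public members state their result types explicitly instead of relying on inference, which helps avoid accidental changes to the inferred public signature.

```scala
trait Normalizer {
  def raw: String

  // Inferred result type -- the style the patch removes from public APIs:
  //   def normalize = raw.toLowerCase()

  // Explicit result type -- the style the patch adds:
  def normalize: String = raw.toLowerCase()
}
```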
--- .../sql/catalyst/AbstractSparkSQLParser.scala | 8 +- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/catalyst/analysis/Catalog.scala | 22 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../catalyst/analysis/FunctionRegistry.scala | 14 +- .../spark/sql/catalyst/analysis/package.scala | 2 +- .../sql/catalyst/analysis/unresolved.scala | 78 +++---- .../spark/sql/catalyst/dsl/package.scala | 202 +++++++++--------- .../catalyst/expressions/AttributeMap.scala | 4 +- .../catalyst/expressions/AttributeSet.scala | 24 ++- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 10 +- .../sql/catalyst/expressions/Expression.scala | 14 +- .../sql/catalyst/expressions/Projection.scala | 52 ++--- .../spark/sql/catalyst/expressions/Rand.scala | 13 +- .../sql/catalyst/expressions/ScalaUdf.scala | 4 +- .../sql/catalyst/expressions/SortOrder.scala | 7 +- .../expressions/SpecificMutableRow.scala | 80 +++---- .../sql/catalyst/expressions/aggregates.scala | 165 +++++++------- .../sql/catalyst/expressions/arithmetic.scala | 66 +++--- .../expressions/codegen/CodeGenerator.scala | 2 +- .../catalyst/expressions/complexTypes.scala | 26 +-- .../expressions/decimalFunctions.scala | 12 +- .../sql/catalyst/expressions/generators.scala | 6 +- .../sql/catalyst/expressions/literals.scala | 13 +- .../expressions/namedExpressions.scala | 51 +++-- .../catalyst/expressions/nullFunctions.scala | 21 +- .../sql/catalyst/expressions/predicates.scala | 80 +++---- .../spark/sql/catalyst/expressions/rows.scala | 18 +- .../spark/sql/catalyst/expressions/sets.scala | 35 +-- .../expressions/stringOperations.scala | 37 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 14 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 6 +- .../plans/logical/basicOperators.scala | 46 ++-- .../catalyst/plans/logical/partitioning.scala | 11 +- .../plans/physical/partitioning.scala | 24 +-- .../spark/sql/catalyst/trees/TreeNode.scala | 27 ++- .../spark/sql/catalyst/trees/package.scala | 4 +- .../spark/sql/catalyst/util/package.scala | 6 +- 40 files changed, 626 insertions(+), 586 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 366be00473d1c..3823584287741 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,7 +26,7 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] object KeywordNormalizer { - def apply(str: String) = str.toLowerCase() + def apply(str: String): String = str.toLowerCase() } private[sql] abstract class AbstractSparkSQLParser @@ -42,7 +42,7 @@ private[sql] abstract class AbstractSparkSQLParser } protected case class Keyword(str: String) { - def normalize = KeywordNormalizer(str) + def normalize: String = KeywordNormalizer(str) def parser: Parser[String] = normalize } @@ -81,7 +81,7 @@ private[sql] abstract class AbstractSparkSQLParser class SqlLexical extends StdLexical { case class FloatLit(chars: String) extends Token { - override def toString = chars + override def toString: String = chars } /* This is a work around to support the lazy setting */ @@ -120,7 +120,7 @@ class SqlLexical extends StdLexical { | failure("illegal character") ) - override def 
identChar = letter | elem('_') + override def identChar: Parser[Elem] = letter | elem('_') override def whitespace: Parser[Any] = ( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 13d2ae4c6fcb6..44eceb0b372e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -168,7 +168,7 @@ class Analyzer(catalog: Catalog, * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation) = { + def getTable(u: UnresolvedRelation): LogicalPlan = { try { catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 9e6e2912e0622..5eb7dff0cede8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -86,12 +86,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables += ((getDbTableName(tableIdent), plan)) } - override def unregisterTable(tableIdentifier: Seq[String]) = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) tables -= getDbTableName(tableIdent) } - override def unregisterAllTables() = { + override def unregisterAllTables(): Unit = { tables.clear() } @@ -147,8 +147,8 @@ trait OverrideCatalog extends Catalog { } abstract override def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None): LogicalPlan = { + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val overriddenTable = overrides.get(getDBTable(tableIdent)) val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) @@ -205,15 +205,15 @@ trait OverrideCatalog extends Catalog { */ object EmptyCatalog extends Catalog { - val caseSensitive: Boolean = true + override val caseSensitive: Boolean = true - def tableExists(tableIdentifier: Seq[String]): Boolean = { + override def tableExists(tableIdentifier: Seq[String]): Boolean = { throw new UnsupportedOperationException } - def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None) = { + override def lookupRelation( + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { throw new UnsupportedOperationException } @@ -221,11 +221,11 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } - def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 425e1e41cbf21..40472a1cbb3b4 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -33,7 +33,7 @@ class CheckAnalysis { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil - def failAnalysis(msg: String) = { + def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9f334f6d42ad1..c43ea55899695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -35,7 +35,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry { val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -47,7 +47,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry { class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -61,13 +61,15 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - def registerFunction(name: String, builder: FunctionBuilder) = ??? + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { + throw new UnsupportedOperationException + } - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - def caseSensitive: Boolean = ??? + override def caseSensitive: Boolean = throw new UnsupportedOperationException } /** @@ -76,7 +78,7 @@ object EmptyFunctionRegistry extends FunctionRegistry { * TODO move this into util folder? */ object StringKeyHashMap { - def apply[T](caseSensitive: Boolean) = caseSensitive match { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { case false => new StringKeyHashMap[T](_.toLowerCase) case true => new StringKeyHashMap[T](identity) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index a7d3a8ee7deb3..c61c395cb4bb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -38,7 +38,7 @@ package object analysis { implicit class AnalysisErrorAt(t: TreeNode[_]) { /** Fails the analysis at the point where a specific tree node was parsed. 
*/ - def failAnalysis(msg: String) = { + def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index ad5172c0349eb..300e9ba187bc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.DataType /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -38,9 +39,10 @@ case class UnresolvedRelation( alias: Option[String] = None) extends LeafNode { /** Returns a `.` separated name for this relation. */ - def tableName = tableIdentifier.mkString(".") + def tableName: String = tableIdentifier.mkString(".") + + override def output: Seq[Attribute] = Nil - override def output = Nil override lazy val resolved = false } @@ -48,16 +50,16 @@ case class UnresolvedRelation( * Holds the name of an attribute that has yet to be resolved. */ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = UnresolvedAttribute(name) + override def newInstance(): UnresolvedAttribute = this + override def withNullability(newNullability: Boolean): UnresolvedAttribute = this + override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) // Unresolved attributes are transient at compile time and don't get evaluated during execution. 
override def eval(input: Row = null): EvaluatedType = @@ -67,16 +69,16 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo } case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"'$name(${children.mkString(",")})" + override def toString: String = s"'$name(${children.mkString(",")})" } /** @@ -86,17 +88,17 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E trait Star extends Attribute with trees.LeafNode[Expression] { self: Product => - override def name = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def name: String = throw new UnresolvedException(this, "name") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = this + override def newInstance(): Star = this + override def withNullability(newNullability: Boolean): Star = this + override def withQualifiers(newQualifiers: Seq[String]): Star = this + override def withName(newName: String): Star = this // Star gets expanded at runtime so we never evaluate a Star. 
override def eval(input: Row = null): EvaluatedType = @@ -129,7 +131,7 @@ case class UnresolvedStar(table: Option[String]) extends Star { } } - override def toString = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = table.map(_ + ".").getOrElse("") + "*" } /** @@ -144,25 +146,25 @@ case class UnresolvedStar(table: Option[String]) extends Star { case class MultiAlias(child: Expression, names: Seq[String]) extends Attribute with trees.UnaryNode[Expression] { - override def name = throw new UnresolvedException(this, "name") + override def name: String = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this + override def newInstance(): MultiAlias = this - override def withNullability(newNullability: Boolean) = this + override def withNullability(newNullability: Boolean): MultiAlias = this - override def withQualifiers(newQualifiers: Seq[String]) = this + override def withQualifiers(newQualifiers: Seq[String]): MultiAlias = this - override def withName(newName: String) = this + override def withName(newName: String): MultiAlias = this override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") @@ -179,17 +181,17 @@ case class MultiAlias(child: Expression, names: Seq[String]) */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions - override def toString = expressions.mkString("ResolvedStar(", ", ", ")") + override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") - override def toString = s"$child.$fieldName" + override def toString: String = s"$child.$fieldName" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7f5f617812919..145f062dd6817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -61,60 +61,60 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- = UnaryMinus(expr) - def unary_! = Not(expr) - def unary_~ = BitwiseNot(expr) - - def + (other: Expression) = Add(expr, other) - def - (other: Expression) = Subtract(expr, other) - def * (other: Expression) = Multiply(expr, other) - def / (other: Expression) = Divide(expr, other) - def % (other: Expression) = Remainder(expr, other) - def & (other: Expression) = BitwiseAnd(expr, other) - def | (other: Expression) = BitwiseOr(expr, other) - def ^ (other: Expression) = BitwiseXor(expr, other) - - def && (other: Expression) = And(expr, other) - def || (other: Expression) = Or(expr, other) - - def < (other: Expression) = LessThan(expr, other) - def <= (other: Expression) = LessThanOrEqual(expr, other) - def > (other: Expression) = GreaterThan(expr, other) - def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = EqualTo(expr, other) - def <=> (other: Expression) = EqualNullSafe(expr, other) - def !== (other: Expression) = Not(EqualTo(expr, other)) - - def in(list: Expression*) = In(expr, list) - - def like(other: Expression) = Like(expr, other) - def rlike(other: Expression) = RLike(expr, other) - def contains(other: Expression) = Contains(expr, other) - def startsWith(other: Expression) = StartsWith(expr, other) - def endsWith(other: Expression) = EndsWith(expr, other) - def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def unary_- : Expression= UnaryMinus(expr) + def unary_! 
: Predicate = Not(expr) + def unary_~ : Expression = BitwiseNot(expr) + + def + (other: Expression): Expression = Add(expr, other) + def - (other: Expression): Expression = Subtract(expr, other) + def * (other: Expression): Expression = Multiply(expr, other) + def / (other: Expression): Expression = Divide(expr, other) + def % (other: Expression): Expression = Remainder(expr, other) + def & (other: Expression): Expression = BitwiseAnd(expr, other) + def | (other: Expression): Expression = BitwiseOr(expr, other) + def ^ (other: Expression): Expression = BitwiseXor(expr, other) + + def && (other: Expression): Predicate = And(expr, other) + def || (other: Expression): Predicate = Or(expr, other) + + def < (other: Expression): Predicate = LessThan(expr, other) + def <= (other: Expression): Predicate = LessThanOrEqual(expr, other) + def > (other: Expression): Predicate = GreaterThan(expr, other) + def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other) + def === (other: Expression): Predicate = EqualTo(expr, other) + def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) + def !== (other: Expression): Predicate = Not(EqualTo(expr, other)) + + def in(list: Expression*): Expression = In(expr, list) + + def like(other: Expression): Expression = Like(expr, other) + def rlike(other: Expression): Expression = RLike(expr, other) + def contains(other: Expression): Expression = Contains(expr, other) + def startsWith(other: Expression): Expression = StartsWith(expr, other) + def endsWith(other: Expression): Expression = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def isNull = IsNull(expr) - def isNotNull = IsNotNull(expr) + def isNull: Predicate = IsNull(expr) + def isNotNull: Predicate = IsNotNull(expr) - def getItem(ordinal: Expression) = GetItem(expr, ordinal) - def getField(fieldName: String) = UnresolvedGetField(expr, fieldName) + def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal) + def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName) - def cast(to: DataType) = Cast(expr, to) + def cast(to: DataType): Expression = Cast(expr, to) - def asc = SortOrder(expr, Ascending) - def desc = SortOrder(expr, Descending) + def asc: SortOrder = SortOrder(expr, Ascending) + def desc: SortOrder = SortOrder(expr, Descending) - def as(alias: String) = Alias(expr, alias)() - def as(alias: Symbol) = Alias(expr, alias.name)() + def as(alias: String): NamedExpression = Alias(expr, alias)() + def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } trait ExpressionConversions { implicit class DslExpression(e: Expression) extends ImplicitOperators { - def expr = e + def expr: Expression = e } implicit def booleanToLiteral(b: Boolean): Literal = Literal(b) @@ -144,94 +144,100 @@ package object dsl { } } - def sum(e: Expression) = Sum(e) - def sumDistinct(e: Expression) = SumDistinct(e) - def count(e: Expression) = Count(e) - def countDistinct(e: Expression*) = CountDistinct(e) - def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) - def avg(e: Expression) = Average(e) - def first(e: Expression) = First(e) - def last(e: Expression) = Last(e) - def min(e: Expression) = Min(e) - def max(e: Expression) = Max(e) - def 
upper(e: Expression) = Upper(e) - def lower(e: Expression) = Lower(e) - def sqrt(e: Expression) = Sqrt(e) - def abs(e: Expression) = Abs(e) - - implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + def sum(e: Expression): Expression = Sum(e) + def sumDistinct(e: Expression): Expression = SumDistinct(e) + def count(e: Expression): Expression = Count(e) + def countDistinct(e: Expression*): Expression = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = + ApproxCountDistinct(e, rsd) + def avg(e: Expression): Expression = Average(e) + def first(e: Expression): Expression = First(e) + def last(e: Expression): Expression = Last(e) + def min(e: Expression): Expression = Min(e) + def max(e: Expression): Expression = Max(e) + def upper(e: Expression): Expression = Upper(e) + def lower(e: Expression): Expression = Lower(e) + def sqrt(e: Expression): Expression = Sqrt(e) + def abs(e: Expression): Expression = Abs(e) + + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { override def expr: Expression = Literal(s) - def attr = analysis.UnresolvedAttribute(s) + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) } abstract class ImplicitAttribute extends ImplicitOperators { def s: String - def expr = attr - def attr = analysis.UnresolvedAttribute(s) + def expr: UnresolvedAttribute = attr + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) /** Creates a new AttributeReference of type boolean */ - def boolean = AttributeReference(s, BooleanType, nullable = true)() + def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)() /** Creates a new AttributeReference of type byte */ - def byte = AttributeReference(s, ByteType, nullable = true)() + def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)() /** Creates a new AttributeReference of type short */ - def short = AttributeReference(s, ShortType, nullable = true)() + def short: AttributeReference = AttributeReference(s, ShortType, nullable = true)() /** Creates a new AttributeReference of type int */ - def int = AttributeReference(s, IntegerType, nullable = true)() + def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)() /** Creates a new AttributeReference of type long */ - def long = AttributeReference(s, LongType, nullable = true)() + def long: AttributeReference = AttributeReference(s, LongType, nullable = true)() /** Creates a new AttributeReference of type float */ - def float = AttributeReference(s, FloatType, nullable = true)() + def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)() /** Creates a new AttributeReference of type double */ - def double = AttributeReference(s, DoubleType, nullable = true)() + def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)() /** Creates a new AttributeReference of type string */ - def string = AttributeReference(s, StringType, nullable = true)() + def string: AttributeReference = AttributeReference(s, StringType, nullable = true)() /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() + def date: AttributeReference = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, 
DecimalType.Unlimited, nullable = true)() + def decimal: AttributeReference = + AttributeReference(s, DecimalType.Unlimited, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal(precision: Int, scale: Int) = + def decimal(precision: Int, scale: Int): AttributeReference = AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ - def timestamp = AttributeReference(s, TimestampType, nullable = true)() + def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() /** Creates a new AttributeReference of type binary */ - def binary = AttributeReference(s, BinaryType, nullable = true)() + def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)() /** Creates a new AttributeReference of type array */ - def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + def array(dataType: DataType): AttributeReference = + AttributeReference(s, ArrayType(dataType), nullable = true)() /** Creates a new AttributeReference of type map */ def map(keyType: DataType, valueType: DataType): AttributeReference = map(MapType(keyType, valueType)) - def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + def map(mapType: MapType): AttributeReference = + AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) - def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() + def struct(structType: StructType): AttributeReference = + AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { - def notNull = a.withNullability(false) - def nullable = a.withNullability(true) + def notNull: AttributeReference = a.withNullability(false) + def nullable: AttributeReference = a.withNullability(true) // Protobuf terminology - def required = a.withNullability(false) + def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) + def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } @@ -241,23 +247,23 @@ package object dsl { abstract class LogicalPlanFunctions { def logicalPlan: LogicalPlan - def select(exprs: NamedExpression*) = Project(exprs, logicalPlan) + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression) = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, - condition: Option[Expression] = None) = + condition: Option[Expression] = None): LogicalPlan = Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): 
LogicalPlan = { val aliasedExprs = aggregateExprs.map { case ne: NamedExpression => ne case e => Alias(e, e.toString)() @@ -265,27 +271,27 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) def sample( fraction: Double, withReplacement: Boolean = true, - seed: Int = (math.random * 1000).toInt) = + seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None) = + alias: Option[String] = None): LogicalPlan = Generate(generator, join, outer, None, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false) = + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) @@ -294,12 +300,14 @@ package object dsl { object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String) = WriteToFile(path, logicalPlan) + def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) } } case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + def call(args: Expression*): ScalaUdf = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } } // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 82e760b6c6916..96a11e352ec50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -23,7 +23,9 @@ package org.apache.spark.sql.catalyst.expressions * of the name, or the expected nullability). 
*/ object AttributeMap { - def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { + new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + } } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index adaeab0b5c027..f9ae85a5cfc1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -19,27 +19,27 @@ package org.apache.spark.sql.catalyst.expressions protected class AttributeEquals(val a: Attribute) { - override def hashCode() = a match { + override def hashCode(): Int = a match { case ar: AttributeReference => ar.exprId.hashCode() case a => a.hashCode() } - override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match { case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId case (a1, a2) => a1 == a2 } } object AttributeSet { - def apply(a: Attribute) = - new AttributeSet(Set(new AttributeEquals(a))) + def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]) = + def apply(baseSet: Seq[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) .map(new AttributeEquals(_)).toSet) + } } /** @@ -57,7 +57,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] with Serializable { /** Returns true if the members of this AttributeSet and other are the same. */ - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) case _ => false } @@ -81,32 +81,34 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in * `other`. */ - def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet) /** * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found * in `other`. */ - def --(other: Traversable[NamedExpression]) = + def --(other: Traversable[NamedExpression]): AttributeSet = new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found * in `other`. */ - def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet) /** * Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] where `f` evaluates to * true. 
*/ - override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + override def filter(f: Attribute => Boolean): AttributeSet = + new AttributeSet(baseSet.filter(ae => f(ae.a))) /** * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in * `this` and `other`. */ - def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + def intersect(other: AttributeSet): AttributeSet = + new AttributeSet(baseSet.intersect(other.baseSet)) override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 76a9f08dea85f..2225621dbaabd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def toString = s"input[$ordinal]" + override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b1bc858478ee1..9bde74ac22669 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -29,9 +29,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) - override def foldable = child.foldable + override def foldable: Boolean = child.foldable - override def nullable = forceNullable(child.dataType, dataType) || child.nullable + override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true @@ -103,7 +103,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - override def toString = s"CAST($child, $dataType)" + override def toString: String = s"CAST($child, $dataType)" type EvaluatedType = Any @@ -430,14 +430,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w object Cast { // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { + override def initialValue(): SimpleDateFormat = { new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } } // `SimpleDateFormat` is not thread-safe. 
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { + override def initialValue(): SimpleDateFormat = { new SimpleDateFormat("yyyy-MM-dd") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6ad39b8372cfb..4e3bbc06a5b4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -65,7 +65,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** * Returns a string representation of this expression that does not have developer centric @@ -84,9 +84,9 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express def symbol: String - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def toString = s"($left $symbol $right)" + override def toString: String = s"($left $symbol $right)" } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -104,8 +104,8 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = ??? - override def nullable = false - override def foldable = false - override def dataType = ??? 
+ override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = false + override def foldable: Boolean = false + override def dataType: DataType = throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index db5d897ee569f..c2866cd955409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -40,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { new GenericRow(outputArray) } - override def toString = s"Row => [${exprArray.mkString(",")}]" + override def toString: String = s"Row => [${exprArray.mkString(",")}]" } /** @@ -107,12 +107,12 @@ class JoinedRow extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -142,7 +142,7 @@ class JoinedRow extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -153,7 +153,7 @@ class JoinedRow extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -207,12 +207,12 @@ class JoinedRow2 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -242,7 +242,7 @@ class JoinedRow2 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -253,7 +253,7 @@ class JoinedRow2 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. 
if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -301,12 +301,12 @@ class JoinedRow3 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -336,7 +336,7 @@ class JoinedRow3 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -347,7 +347,7 @@ class JoinedRow3 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -395,12 +395,12 @@ class JoinedRow4 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -430,7 +430,7 @@ class JoinedRow4 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -441,7 +441,7 @@ class JoinedRow4 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -489,12 +489,12 @@ class JoinedRow5 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -524,7 +524,7 @@ class JoinedRow5 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -535,7 +535,7 @@ class JoinedRow5 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. 
if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index b2c6d3029031d..f5fea3f015dc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Random -import org.apache.spark.sql.types.DoubleType + +import org.apache.spark.sql.types.{DataType, DoubleType} case object Rand extends LeafExpression { - override def dataType = DoubleType - override def nullable = false + override def dataType: DataType = DoubleType + override def nullable: Boolean = false private[this] lazy val rand = new Random - override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType] + override def eval(input: Row = null): EvaluatedType = { + rand.nextDouble().asInstanceOf[EvaluatedType] + } - override def toString = "RAND()" + override def toString: String = "RAND()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 8a36c6810790d..1fd5ce342b2ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,9 +29,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true - override def toString = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d00b2ac09745c..83074eb1e6310 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.DataType abstract sealed class SortDirection case object Ascending extends SortDirection @@ -31,12 +32,12 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") - override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 21d714c9a8c3b..47b6f358ed1b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,126 +62,126 @@ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any def update(v: Any) - def copy(): this.type + def copy(): MutableValue } final class MutableInt extends MutableValue { var value: Int = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Int] + value = v.asInstanceOf[Int] } - def copy() = { + override def copy(): MutableInt = { val newCopy = new MutableInt newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableInt] } } final class MutableFloat extends MutableValue { var value: Float = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Float] + value = v.asInstanceOf[Float] } - def copy() = { + override def copy(): MutableFloat = { val newCopy = new MutableFloat newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableFloat] } } final class MutableBoolean extends MutableValue { var value: Boolean = false - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Boolean] + value = v.asInstanceOf[Boolean] } - def copy() = { + override def copy(): MutableBoolean = { val newCopy = new MutableBoolean newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableBoolean] } } final class MutableDouble extends MutableValue { var value: Double = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Double] + value = v.asInstanceOf[Double] } - def copy() = { + override def copy(): MutableDouble = { val newCopy = new MutableDouble newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableDouble] } } final class MutableShort extends MutableValue { var value: Short = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Short] } - def copy() = { + override def copy(): MutableShort = { val newCopy = new MutableShort newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableShort] } } final class MutableLong extends MutableValue { var value: Long = 0 - def boxed = if (isNull) null else value - def 
update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Long] } - def copy() = { + override def copy(): MutableLong = { val newCopy = new MutableLong newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableLong] } } final class MutableByte extends MutableValue { var value: Byte = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Byte] } - def copy() = { + override def copy(): MutableByte = { val newCopy = new MutableByte newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableByte] } } final class MutableAny extends MutableValue { var value: Any = _ - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Any] + value = v.asInstanceOf[Any] } - def copy() = { + override def copy(): MutableAny = { val newCopy = new MutableAny newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableAny] } } @@ -234,9 +234,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR if (value == null) setNullAt(ordinal) else values(ordinal).update(value) } - override def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) - override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5297d1e31246c..30da4faa3f1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -79,27 +79,29 @@ abstract class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def nullable = base.nullable - override def dataType = base.dataType + override def nullable: Boolean = base.nullable + override def dataType: DataType = base.dataType def update(input: Row): Unit // Do we really need this? 
- override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override def newInstance(): AggregateFunction = { + makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + } } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MIN($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) } - override def newInstance() = new MinFunction(child, this) + override def newInstance(): MinFunction = new MinFunction(child, this) } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -121,16 +123,16 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MAX($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) } - override def newInstance() = new MaxFunction(child, this) + override def newInstance(): MaxFunction = new MaxFunction(child, this) } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -152,29 +154,29 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT($child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) } - override def newInstance() = new CountFunction(child, this) + override def newInstance(): CountFunction = new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) - override def children = expressions + override def children: Seq[Expression] = expressions - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance() = new CountDistinctFunction(expressions, this) + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(expressions), "partialSets")() SplitEvaluation( CombineSetsAndCount(partialSet.toAttribute), @@ -185,11 +187,11 @@ case class CountDistinct(expressions: Seq[Expression]) 
extends PartialAggregate case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) - override def children = expressions - override def nullable = false - override def dataType = ArrayType(expressions.head.dataType) - override def toString = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance() = new CollectHashSetFunction(expressions, this) + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: ArrayType = ArrayType(expressions.head.dataType) + override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" + override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -219,11 +221,13 @@ case class CollectHashSetFunction( case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { def this() = this(null) - override def children = inputSet :: Nil - override def nullable = false - override def dataType = LongType - override def toString = s"CombineAndCount($inputSet)" - override def newInstance() = new CombineSetsAndCountFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"CombineAndCount($inputSet)" + override def newInstance(): CombineSetsAndCountFunction = { + new CombineSetsAndCountFunction(inputSet, this) + } } case class CombineSetsAndCountFunction( @@ -249,27 +253,31 @@ case class CombineSetsAndCountFunction( case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: DataType = child.dataType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctPartitionFunction = { + new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + } } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctMergeFunction = { + new ApproxCountDistinctMergeFunction(child, this, relativeSD) + } } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def asPartial: SplitEvaluation = { val partialCount = @@ -280,14 +288,14 @@ case class ApproxCountDistinct(child: Expression, 
relativeSD: Double = 0.05) partialCount :: Nil) } - override def newInstance() = new CountDistinctFunction(child :: Nil, this) + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive case DecimalType.Unlimited => @@ -296,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString = s"AVG($child)" + override def toString: String = s"AVG($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -323,14 +331,14 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } } - override def newInstance() = new AverageFunction(child, this) + override def newInstance(): AverageFunction = new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -339,7 +347,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString = s"SUM($child)" + override def toString: String = s"SUM($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -357,7 +365,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ } } - override def newInstance() = new SumFunction(child, this) + override def newInstance(): SumFunction = new SumFunction(child, this) } /** @@ -377,19 +385,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class CombineSum(child: Expression) extends AggregateExpression { def this() = this(null) - override def children = child :: Nil - override def nullable = true - override def dataType = child.dataType - override def toString = s"CombineSum($child)" - override def newInstance() = new CombineSumFunction(child, this) + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } case class SumDistinct(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { def this() = this(null) - override def nullable = true - override def dataType = child.dataType match { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -397,10 +405,10 @@ case class SumDistinct(child: Expression) case _ => child.dataType } - override def toString = s"SUM(DISTINCT ${child})" - override def newInstance() = new SumDistinctFunction(child, this) + override def 
toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() SplitEvaluation( CombineSetsAndSum(partialSet.toAttribute, this), @@ -411,11 +419,13 @@ case class SumDistinct(child: Expression) case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { def this() = this(null, null) - override def children = inputSet :: Nil - override def nullable = true - override def dataType = base.dataType - override def toString = s"CombineAndSum($inputSet)" - override def newInstance() = new CombineSetsAndSumFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } } case class CombineSetsAndSumFunction( @@ -449,9 +459,9 @@ case class CombineSetsAndSumFunction( } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"FIRST($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" override def asPartial: SplitEvaluation = { val partialFirst = Alias(First(child), "PartialFirst")() @@ -459,14 +469,14 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance() = new FirstFunction(child, this) + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references - override def nullable = true - override def dataType = child.dataType - override def toString = s"LAST($child)" + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" override def asPartial: SplitEvaluation = { val partialLast = Alias(Last(child), "PartialLast")() @@ -474,7 +484,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode Last(partialLast.toAttribute), partialLast :: Nil) } - override def newInstance() = new LastFunction(child, this) + override def newInstance(): LastFunction = new LastFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -713,6 +723,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg result = input } - override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row]) - else null + override def eval(input: Row): Any = { + if (result != null) expr.eval(result.asInstanceOf[Row]) else null + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 00b0d3c683fe2..1f6526ef66c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.types._ case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"-$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"-$child" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -47,10 +47,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = DoubleType - override def foldable = child.foldable - def nullable = true - override def toString = s"SQRT($child)" + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"SQRT($child)" lazy val numeric = child.dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -74,14 +74,14 @@ abstract class BinaryArithmetic extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + def nullable: Boolean = left.nullable || right.nullable override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType && !DecimalType.isFixed(left.dataType) - def dataType = { + def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -108,7 +108,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "+" + override def symbol: String = "+" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -131,7 +131,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "-" + override def symbol: String = "-" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -154,7 +154,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "*" + override def symbol: String = "*" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -177,9 +177,9 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "/" + override def symbol: String = "/" - override def nullable = true + override def nullable: Boolean = true lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div @@ -203,9 +203,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "%" + override def symbol: String = "%" - override def nullable = true + override def nullable: Boolean = true lazy val integral 
= dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] @@ -232,7 +232,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise and(&) of two numbers. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "&" + override def symbol: String = "&" lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -253,7 +253,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * A function that calculates bitwise or(|) of two numbers. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "|" + override def symbol: String = "|" lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -274,7 +274,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise xor(^) of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "^" + override def symbol: String = "^" lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -297,10 +297,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case class BitwiseNot(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"~$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"~$child" lazy val not: (Any) => Any = dataType match { case ByteType => @@ -327,17 +327,17 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { case class MaxOf(left: Expression, right: Expression) extends Expression { type EvaluatedType = Any - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def nullable = left.nullable && right.nullable + override def nullable: Boolean = left.nullable && right.nullable - override def children = left :: right :: Nil + override def children: Seq[Expression] = left :: right :: Nil override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType - override def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. 
Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -366,7 +366,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } } - override def toString = s"MaxOf($left, $right)" + override def toString: String = s"MaxOf($left, $right)" } /** @@ -375,10 +375,10 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { case class Abs(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"Abs($child)" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"Abs($child)" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e48b8cde20eda..d1abf3c0b64a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -91,7 +91,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() - def timeMs = (endTime - startTime).toDouble / 1000000 + def timeMs: Double = (endTime - startTime).toDouble / 1000000 logInfo(s"Code generated expression $in in $timeMs ms") result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 68051a2a2007e..3fd78db297462 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -27,12 +27,12 @@ import org.apache.spark.sql.types._ case class GetItem(child: Expression, ordinal: Expression) extends Expression { type EvaluatedType = Any - val children = child :: ordinal :: Nil + val children: Seq[Expression] = child :: ordinal :: Nil /** `Null` is returned for invalid ordinals. 
*/ - override def nullable = true - override def foldable = child.foldable && ordinal.foldable + override def nullable: Boolean = true + override def foldable: Boolean = child.foldable && ordinal.foldable - def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt } @@ -40,7 +40,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { childrenResolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - override def toString = s"$child[$ordinal]" + override def toString: String = s"$child[$ordinal]" override def eval(input: Row): Any = { val value = child.eval(input) @@ -75,8 +75,8 @@ trait GetField extends UnaryExpression { self: Product => type EvaluatedType = Any - override def foldable = child.foldable - override def toString = s"$child.${field.name}" + override def foldable: Boolean = child.foldable + override def toString: String = s"$child.${field.name}" def field: StructField } @@ -86,8 +86,8 @@ trait GetField extends UnaryExpression { */ case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField { - def dataType = field.dataType - override def nullable = child.nullable || field.nullable + override def dataType: DataType = field.dataType + override def nullable: Boolean = child.nullable || field.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] @@ -101,8 +101,8 @@ case class StructGetField(child: Expression, field: StructField, ordinal: Int) e case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean) extends GetField { - def dataType = ArrayType(field.dataType, containsNull) - override def nullable = child.nullable + override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[Row]] @@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -140,5 +140,5 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString = s"Array(${children.mkString(",")})" + override def toString: String = s"Array(${children.mkString(",")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 83d8c1d42bca4..adb94df7d1c7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -24,9 +24,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { override type EvaluatedType = Any override def dataType: DataType = LongType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"UnscaledValue($child)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: 
String = s"UnscaledValue($child)" override def eval(input: Row): Any = { val childResult = child.eval(input) @@ -43,9 +43,9 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"MakeDecimal($child,$precision,$scale)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: Row): Decimal = { val childResult = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 0983d274def3f..860b72fad38b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -45,7 +45,7 @@ abstract class Generator extends Expression { override lazy val dataType = ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) - override def nullable = false + override def nullable: Boolean = false /** * Should be overridden by specific generators. Called only once for each instance to ensure @@ -89,7 +89,7 @@ case class UserDefinedGenerator( function(inputRow(input)) } - override def toString = s"UserDefinedGenerator(${children.mkString(",")})" + override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" } /** @@ -130,5 +130,5 @@ case class Explode(attributeNames: Seq[String], child: Expression) } } - override def toString() = s"explode($child)" + override def toString: String = s"explode($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 9ff66563c8164..19f3fc9c2291a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -64,14 +64,13 @@ object IntegerLiteral { case class Literal(value: Any, dataType: DataType) extends LeafExpression { - override def foldable = true - def nullable = value == null + override def foldable: Boolean = true + override def nullable: Boolean = value == null - - override def toString = if (value != null) value.toString else "null" + override def toString: String = if (value != null) value.toString else "null" type EvaluatedType = Any - override def eval(input: Row):Any = value + override def eval(input: Row): Any = value } // TODO: Specialize @@ -79,9 +78,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean extends LeafExpression { type EvaluatedType = Any - def update(expression: Expression, input: Row) = { + def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) } - override def eval(input: Row) = value + override def eval(input: Row): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 08361d043b6ed..bcbcbeb31c7b5 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.trees.LeafNode import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() - def newExprId = ExprId(curId.getAndIncrement()) + def newExprId: ExprId = ExprId(curId.getAndIncrement()) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) } @@ -79,13 +80,13 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => - override def references = AttributeSet(this) + override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute - def toAttribute = this + def toAttribute: Attribute = this def newInstance(): Attribute } @@ -112,10 +113,10 @@ case class Alias(child: Expression, name: String)( override type EvaluatedType = Any - override def eval(input: Row) = child.eval(input) + override def eval(input: Row): Any = child.eval(input) - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable override def metadata: Metadata = { explicitMetadata.getOrElse { child match { @@ -125,7 +126,7 @@ case class Alias(child: Expression, name: String)( } } - override def toAttribute = { + override def toAttribute: Attribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { @@ -135,7 +136,9 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" - override protected final def otherCopyArgs = exprId :: qualifiers :: explicitMetadata :: Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + exprId :: qualifiers :: explicitMetadata :: Nil + } override def equals(other: Any): Boolean = other match { case a: Alias => @@ -166,7 +169,7 @@ case class AttributeReference( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } @@ -180,7 +183,7 @@ case class AttributeReference( h } - override def newInstance() = + override def newInstance(): AttributeReference = AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** @@ -205,7 +208,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifiers. 
*/ - override def withQualifiers(newQualifiers: Seq[String]) = { + override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = { if (newQualifiers.toSet == qualifiers.toSet) { this } else { @@ -227,20 +230,22 @@ case class AttributeReference( case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { type EvaluatedType = Any - override def toString = name - - override def withNullability(newNullability: Boolean): Attribute = ??? - override def newInstance(): Attribute = ??? - override def withQualifiers(newQualifiers: Seq[String]): Attribute = ??? - override def withName(newName: String): Attribute = ??? - override def qualifiers: Seq[String] = ??? - override def exprId: ExprId = ??? - override def eval(input: Row): EvaluatedType = ??? - override def nullable: Boolean = ??? + override def toString: String = name + + override def withNullability(newNullability: Boolean): Attribute = + throw new UnsupportedOperationException + override def newInstance(): Attribute = throw new UnsupportedOperationException + override def withQualifiers(newQualifiers: Seq[String]): Attribute = + throw new UnsupportedOperationException + override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + override def exprId: ExprId = throw new UnsupportedOperationException + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } object VirtualColumn { - val groupingIdName = "grouping__id" - def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdName: String = "grouping__id" + def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 08b982bc671e7..d1f3d4f4ee9ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -19,22 +19,23 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - def nullable = !children.exists(!_.nullable) + override def nullable: Boolean = !children.exists(!_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) // Only resolved if all the children are of the same type. 
override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) - override def toString = s"Coalesce(${children.mkString(",")})" + override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType = if (resolved) { + def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -54,20 +55,20 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false override def eval(input: Row): Any = { child.eval(input) == null } - override def toString = s"IS NULL $child" + override def toString: String = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false - override def toString = s"IS NOT NULL $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false + override def toString: String = s"IS NOT NULL $child" override def eval(input: Row): Any = { child.eval(input) != null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 0024ef92c0452..7e47cb3fffe12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{BinaryType, BooleanType, NativeType} +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -34,7 +34,7 @@ object InterpretedPredicate { trait Predicate extends Expression { self: Product => - def dataType = BooleanType + override def dataType: DataType = BooleanType type EvaluatedType = Any } @@ -72,13 +72,13 @@ trait PredicateHelper { abstract class BinaryPredicate extends BinaryExpression with Predicate { self: Product => - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable } case class Not(child: Expression) extends UnaryExpression with Predicate { - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"NOT $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"NOT $child" override def eval(input: Row): Any = { child.eval(input) match { @@ -92,10 +92,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { * Evaluates to `true` if `list` contains `value`. */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { - def children = value +: list + override def children: Seq[Expression] = value +: list - def nullable = true // TODO: Figure out correct nullability semantics of IN. 
- override def toString = s"$value IN ${list.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: Row): Any = { val evaluatedValue = value.eval(input) @@ -110,10 +110,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case class InSet(value: Expression, hset: Set[Any]) extends Predicate { - def children = value :: Nil + override def children: Seq[Expression] = value :: Nil - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" override def eval(input: Row): Any = { hset.contains(value.eval(input)) @@ -121,7 +121,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "&&" + override def symbol: String = "&&" override def eval(input: Row): Any = { val l = left.eval(input) @@ -143,7 +143,7 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { } case class Or(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "||" + override def symbol: String = "||" override def eval(input: Row): Any = { val l = left.eval(input) @@ -169,7 +169,8 @@ abstract class BinaryComparison extends BinaryPredicate { } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "=" + override def symbol: String = "=" + override def eval(input: Row): Any = { val l = left.eval(input) if (l == null) { @@ -185,8 +186,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=>" - override def nullable = false + override def symbol: String = "<=>" + + override def nullable: Boolean = false + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -201,9 +204,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<" + override def symbol: String = "<" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -216,7 +219,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -230,9 +233,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=" + override def symbol: String = "<=" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -245,7 +248,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override def eval(input: 
Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -259,9 +262,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">" + override def symbol: String = ">" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -288,9 +291,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">=" + override def symbol: String = ">=" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -303,7 +306,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -317,13 +320,13 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends Expression { - def children = predicate :: trueValue :: falseValue :: Nil - override def nullable = trueValue.nullable || falseValue.nullable + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException( this, @@ -342,7 +345,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def toString = s"if ($predicate) $trueValue else $falseValue" + override def toString: String = s"if ($predicate) $trueValue else $falseValue" } // scalastyle:off @@ -362,9 +365,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi // scalastyle:on case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any - def children = branches - def dataType = { + override def children: Seq[Expression] = branches + + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") } @@ -379,12 +383,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { @transient private[this] lazy val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - override def nullable = { + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. 
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } - override lazy val resolved = { + override lazy val resolved: Boolean = { if (!childrenResolved) { false } else { @@ -415,7 +419,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { res } - override def toString = { + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" case Seq(elseValue) => s" ELSE $elseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index f03d6f71a9fae..8bba26bc4cf7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,8 +44,8 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def toSeq = Seq.empty - override def length = 0 + override def toSeq: Seq[Any] = Seq.empty + override def length: Int = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException override def getLong(i: Int): Long = throw new UnsupportedOperationException @@ -56,7 +56,7 @@ object EmptyRow extends Row { override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - def copy() = this + override def copy(): Row = this } /** @@ -70,13 +70,13 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { def this(size: Int) = this(new Array[Any](size)) - override def toSeq = values.toSeq + override def toSeq: Seq[Any] = values.toSeq - override def length = values.length + override def length: Int = values.length - override def apply(i: Int) = values(i) + override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int) = values(i) == null + override def isNullAt(i: Int): Boolean = values(i) == null override def getInt(i: Int): Int = { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") @@ -167,7 +167,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { case _ => false } - def copy() = this + override def copy(): Row = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -194,7 +194,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } - override def copy() = new GenericRow(values.clone()) + override def copy(): Row = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 3a5bdca1f07c3..35faa00782e80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -26,17 +26,17 @@ import org.apache.spark.util.collection.OpenHashSet case class NewSet(elementType: DataType) extends LeafExpression { type EvaluatedType = Any - def nullable = false + override def nullable: Boolean = false // We are currently only using 
these Expressions internally for aggregation. However, if we ever // expose these to users we'll want to create a proper type instead of hijacking ArrayType. - def dataType = ArrayType(elementType) + override def dataType: DataType = ArrayType(elementType) - def eval(input: Row): Any = { + override def eval(input: Row): Any = { new OpenHashSet[Any]() } - override def toString = s"new Set($dataType)" + override def toString: String = s"new Set($dataType)" } /** @@ -46,12 +46,13 @@ case class NewSet(elementType: DataType) extends LeafExpression { case class AddItemToSet(item: Expression, set: Expression) extends Expression { type EvaluatedType = Any - def children = item :: set :: Nil + override def children: Seq[Expression] = item :: set :: Nil - def nullable = set.nullable + override def nullable: Boolean = set.nullable - def dataType = set.dataType - def eval(input: Row): Any = { + override def dataType: DataType = set.dataType + + override def eval(input: Row): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -67,7 +68,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def toString = s"$set += $item" + override def toString: String = s"$set += $item" } /** @@ -77,13 +78,13 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable - def dataType = left.dataType + override def dataType: DataType = left.dataType - def symbol = "++=" + override def symbol: String = "++=" - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -109,16 +110,16 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres case class CountSet(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def nullable = child.nullable + override def nullable: Boolean = child.nullable - def dataType = LongType + override def dataType: DataType = LongType - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] if (childEval != null) { childEval.size.toLong } } - override def toString = s"$child.count()" + override def toString: String = s"$child.count()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f85ee0a9bb6d8..3cdca4e9dd2d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -33,8 +33,8 @@ trait StringRegexExpression { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - def nullable: Boolean = left.nullable || right.nullable - def dataType: DataType = BooleanType + override def nullable: Boolean = left.nullable || right.nullable + override def dataType: DataType = BooleanType // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -98,11 +98,11 @@ trait CaseConversionExpression { case class Like(left: 
Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "LIKE" + override def symbol: String = "LIKE" // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character - override def escape(v: String) = + override def escape(v: String): String = if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" @@ -129,7 +129,7 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "RLIKE" + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } @@ -141,7 +141,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toUpperCase() - override def toString() = s"Upper($child)" + override def toString: String = s"Upper($child)" } /** @@ -151,7 +151,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toLowerCase() - override def toString() = s"Lower($child)" + override def toString: String = s"Lower($child)" } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -160,7 +160,7 @@ trait StringComparison { type EvaluatedType = Any - def nullable: Boolean = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -175,9 +175,9 @@ trait StringComparison { } } - def symbol: String = nodeName + override def symbol: String = nodeName - override def toString() = s"$nodeName($left, $right)" + override def toString: String = s"$nodeName($left, $right)" } /** @@ -185,7 +185,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - override def compare(l: String, r: String) = l.contains(r) + override def compare(l: String, r: String): Boolean = l.contains(r) } /** @@ -193,7 +193,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.startsWith(r) + override def compare(l: String, r: String): Boolean = l.startsWith(r) } /** @@ -201,7 +201,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.endsWith(r) + override def compare(l: String, r: String): Boolean = l.endsWith(r) } /** @@ -212,17 +212,17 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends type EvaluatedType = Any - override def foldable = str.foldable && pos.foldable && len.foldable + override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - def nullable: Boolean = str.nullable || pos.nullable || len.nullable - def dataType: DataType = { + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") } if (str.dataType == BinaryType) str.dataType else StringType } - override def 
children = str :: pos :: len :: Nil + override def children: Seq[Expression] = str :: pos :: len :: Nil @inline def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) @@ -267,7 +267,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends } } - override def toString = len match { + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" case _ => s"SUBSTR($str, $pos, $len)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 74edaacc4f609..c23d3b61887c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -141,7 +141,7 @@ object ColumnPruning extends Rule[LogicalPlan] { condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) + def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b4c445b3badf1..9c8c643f7d17a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -91,16 +91,18 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]) = fields.collect { + def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform { - case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + expr.transform { + case a @ Alias(ref: AttributeReference, name) => + aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) - case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + case a: AttributeReference => + aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index bd9291e9ba5d7..02f7c26a8ab6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -71,7 +71,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionDown(e: Expression) = { + @inline def transformExpressionDown(e: Expression): Expression = { val newE = 
e.transformDown(rule) if (newE.fastEquals(e)) { e @@ -104,7 +104,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionUp(e: Expression) = { + @inline def transformExpressionUp(e: Expression): Expression = { val newE = e.transformUp(rule) if (newE.fastEquals(e)) { e @@ -165,5 +165,5 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy */ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - override def simpleString = statePrefix + super.simpleString + override def simpleString: String = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 384fe53a68362..4d9e41a2b5d85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { - def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) override lazy val resolved: Boolean = { val containsAggregatesOrGenerators = projectList.exists ( _.collect { @@ -66,19 +66,19 @@ case class Generate( } } - override def output = + override def output: Seq[Attribute] = if (join) child.output ++ generatorOutput else generatorOutput } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output = left.output + override def output: Seq[Attribute] = left.output - override lazy val resolved = + override lazy val resolved: Boolean = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } @@ -94,7 +94,7 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftSemi => left.output @@ -109,7 +109,7 @@ case class Join( } } - def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty + private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguious expression ids. 
override lazy val resolved: Boolean = { @@ -118,7 +118,7 @@ case class Join( } case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - def output = left.output + override def output: Seq[Attribute] = left.output } case class InsertIntoTable( @@ -128,10 +128,10 @@ case class InsertIntoTable( overwrite: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) } @@ -143,14 +143,14 @@ case class CreateTableAsSelect[T]( child: LogicalPlan, allowExisting: Boolean, desc: Option[T] = None) extends UnaryNode { - override def output = Seq.empty[Attribute] - override lazy val resolved = databaseName != None && childrenResolved + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = databaseName != None && childrenResolved } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -163,7 +163,7 @@ case class Sort( order: Seq[SortOrder], global: Boolean, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Aggregate( @@ -172,7 +172,7 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } /** @@ -199,7 +199,7 @@ trait GroupingAnalytics extends UnaryNode { def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] - override def output = aggregations.map(_.toAttribute) + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) } /** @@ -264,7 +264,7 @@ case class Rollup( gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { val limit = limitExpr.eval(null).asInstanceOf[Int] @@ -274,21 +274,21 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Distinct(child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case object NoRelation extends LeafNode { - override def output = Nil + override def output: Seq[Attribute] = Nil /** * Computes [[Statistics]] for this plan. 
The default implementation assumes the output @@ -301,5 +301,5 @@ case object NoRelation extends LeafNode { } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 72b0c5c8e7a26..e737418d9c3bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} /** * Performs a physical redistribution of the data. Used when the consumer of the query @@ -26,14 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} abstract class RedistributeData extends UnaryNode { self: Product => - def output = child.output + override def output: Seq[Attribute] = child.output } case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) - extends RedistributeData { -} + extends RedistributeData case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) - extends RedistributeData { -} - + extends RedistributeData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 3c3d7a3119064..288c11f69fe22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -72,7 +72,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "a single partition.") // TODO: This is not really valid... 
- def clustering = ordering.map(_.child).toSet + def clustering: Set[Expression] = ordering.map(_.child).toSet } sealed trait Partitioning { @@ -113,7 +113,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -124,7 +124,7 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -139,9 +139,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning { - override def children = expressions - override def nullable = false - override def dataType = IntegerType + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = expressions.toSet @@ -152,7 +152,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case h: HashPartitioning if h == this => true case _ => false @@ -178,9 +178,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) extends Expression with Partitioning { - override def children = ordering - override def nullable = false - override def dataType = IntegerType + override def children: Seq[SortOrder] = ordering + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = ordering.map(_.child).toSet @@ -194,7 +194,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case r: RangePartitioning if r == this => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0ae9f6b2965d4..a2df51e598a2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -36,12 +36,12 @@ object CurrentOrigin { override def initialValue: Origin = Origin() } - def get = value.get() - def set(o: Origin) = value.set(o) + def get: Origin = value.get() + def set(o: Origin): Unit = value.set(o) - def reset() = value.set(Origin()) + def reset(): Unit = value.set(Origin()) - def setPosition(line: Int, start: Int) = { + def setPosition(line: Int, start: Int): Unit = { value.set( value.get.copy(line = Some(line), startPosition = Some(start))) } @@ -57,7 +57,7 @@ object CurrentOrigin { abstract class TreeNode[BaseType <: TreeNode[BaseType]] { self: BaseType with Product => - val origin = CurrentOrigin.get + val origin: Origin = CurrentOrigin.get /** Returns a Seq of the children of this node */ def children: Seq[BaseType] @@ -340,12 
+340,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } /** Returns the name of this type of TreeNode. Defaults to the class name. */ - def nodeName = getClass.getSimpleName + def nodeName: String = getClass.getSimpleName /** * The arguments that should be included in the arg string. Defaults to the `productIterator`. */ - protected def stringArgs = productIterator + protected def stringArgs: Iterator[Any] = productIterator /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { @@ -357,18 +357,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { }.mkString(", ") /** String representation of this node without any children */ - def simpleString = s"$nodeName $argString".trim + def simpleString: String = s"$nodeName $argString".trim override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. */ - def numberedTreeString = + def numberedTreeString: String = treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") /** @@ -420,14 +420,14 @@ trait BinaryNode[BaseType <: TreeNode[BaseType]] { def left: BaseType def right: BaseType - def children = Seq(left, right) + def children: Seq[BaseType] = Seq(left, right) } /** * A [[TreeNode]] with no children. */ trait LeafNode[BaseType <: TreeNode[BaseType]] { - def children = Nil + def children: Seq[BaseType] = Nil } /** @@ -435,6 +435,5 @@ trait LeafNode[BaseType <: TreeNode[BaseType]] { */ trait UnaryNode[BaseType <: TreeNode[BaseType]] { def child: BaseType - def children = child :: Nil + def children: Seq[BaseType] = child :: Nil } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 79a8e06d4b4d4..ea6aa1850db4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -41,11 +41,11 @@ package object trees extends Logging { * A [[TreeNode]] companion for reference equality for Hash based Collection. 
*/ class TreeNodeRef(val obj: TreeNode[_]) { - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case that: TreeNodeRef => that.obj.eq(obj) case _ => false } - override def hashCode = if (obj == null) 0 else obj.hashCode + override def hashCode: Int = if (obj == null) 0 else obj.hashCode } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index feed50f9a2a2d..c86214a2aa944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils package object util { - def fileToString(file: File, encoding: String = "UTF-8") = { + def fileToString(file: File, encoding: String = "UTF-8"): String = { val inStream = new FileInputStream(file) val outStream = new ByteArrayOutputStream try { @@ -45,7 +45,7 @@ package object util { def resourceToString( resource:String, encoding: String = "UTF-8", - classLoader: ClassLoader = Utils.getSparkClassLoader) = { + classLoader: ClassLoader = Utils.getSparkClassLoader): String = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { @@ -93,7 +93,7 @@ package object util { new String(out.toByteArray) } - def stringOrNull(a: AnyRef) = if (a == null) null else a.toString + def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString def benchmark[A](f: => A): A = { val startTime = System.nanoTime() From 6948ab6f8ba836446b005f2cf1cc4abc944c5053 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 24 Mar 2015 16:26:43 -0700 Subject: [PATCH 018/171] [SPARK-6088] Correct how tasks that get remote results are shown in UI. It would be great to fix this for 1.3. since the fix is surgical and it helps understandability for users. cc shivaram pwendell Author: Kay Ousterhout Closes #4839 from kayousterhout/SPARK-6088 and squashes the following commits: 3ab012c [Kay Ousterhout] Update getting result time incrementally, correctly set GET_RESULT status f346b49 [Kay Ousterhout] Typos 748ea6b [Kay Ousterhout] Fixed build failure 84d617c [Kay Ousterhout] [SPARK-6088] Correct how tasks that get remote results are shown in the UI. 
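For readability only (this sketch is not part of the patch), the status-precedence rule that the TaskInfo diff below introduces, restated as a self-contained Scala method: a task that is still running but already fetching its remote result reports GET RESULT rather than RUNNING. The parameter form and the trailing SUCCESS/UNKNOWN branches are illustrative assumptions; the real method reads these flags from TaskInfo fields.

// Illustrative sketch of the corrected precedence; not the actual TaskInfo code.
def taskStatus(running: Boolean, gettingResult: Boolean,
               failed: Boolean, successful: Boolean): String = {
  if (running) {
    // A task can be running and fetching its result at the same time; the fetch wins.
    if (gettingResult) "GET RESULT" else "RUNNING"
  } else if (failed) {
    "FAILED"
  } else if (successful) {
    "SUCCESS" // assumed terminal labels, included only to make the sketch complete
  } else {
    "UNKNOWN"
  }
}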
--- .../org/apache/spark/scheduler/TaskInfo.scala | 8 +++--- .../org/apache/spark/ui/jobs/StagePage.scala | 25 +++++++++++++------ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6fa1f2c880f7a..132a9ced77700 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -81,9 +81,11 @@ class TaskInfo( def status: String = { if (running) { - "RUNNING" - } else if (gettingResult) { - "GET RESULT" + if (gettingResult) { + "GET RESULT" + } else { + "RUNNING" + } } else if (failed) { "FAILED" } else if (successful) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index e03442894c5cc..797c9404bc449 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -269,11 +269,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - if (info.gettingResultTime > 0) { - (info.finishTime - info.gettingResultTime).toDouble - } else { - 0.0 - } + getGettingResultTime(info).toDouble } val gettingResultQuantiles = @@ -464,7 +460,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = info.gettingResultTime + val gettingResultTime = getGettingResultTime(info) val maybeAccumulators = info.accumulables val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} @@ -627,6 +623,19 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {errorSummary}{details} } + private def getGettingResultTime(info: TaskInfo): Long = { + if (info.gettingResultTime > 0) { + if (info.finishTime > 0) { + info.finishTime - info.gettingResultTime + } else { + // The task is still fetching the result. + System.currentTimeMillis - info.gettingResultTime + } + } else { + 0L + } + } + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { val totalExecutionTime = if (info.gettingResult) { @@ -638,6 +647,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val executorOverhead = (metrics.executorDeserializeTime + metrics.resultSerializationTime) - math.max(0, totalExecutionTime - metrics.executorRunTime - executorOverhead) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) } } From d8ccf655f344eed65cdaf5d9252f1b565b8406ca Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 24 Mar 2015 16:29:40 -0700 Subject: [PATCH 019/171] [SPARK-3570] Include time to open files in shuffle write time. Opening shuffle files can be very significant when the disk is contended, especially when using ext3. While writing data to a file can avoid hitting disk (and instead hit the buffer cache), opening a file always involves writing some metadata about the file to disk, so the open time can be a very significant portion of the shuffle write time. 
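For illustration only (this snippet is not part of the patch; the temp-file name and the 4 KB payload are arbitrary), a self-contained Scala sketch that times the file open separately from the write, which is exactly the split the change below folds into the shuffle write time metric:

import java.io.{File, FileOutputStream}

object OpenVsWriteTiming {
  def main(args: Array[String]): Unit = {
    // Pick a path that does not exist yet so the open below actually creates the file.
    val file = new File(
      System.getProperty("java.io.tmpdir"), s"shuffle-demo-${System.nanoTime()}.data")

    // Creating/opening the stream forces the filesystem to record metadata,
    // which is the part that can be slow on a contended disk.
    val openStart = System.nanoTime()
    val out = new FileOutputStream(file)
    val openMs = (System.nanoTime() - openStart) / 1e6

    // A small write usually lands in the OS buffer cache and returns quickly.
    val writeStart = System.nanoTime()
    out.write(new Array[Byte](4 * 1024))
    out.close()
    val writeMs = (System.nanoTime() - writeStart) / 1e6

    println(f"open: $openMs%.3f ms, write: $writeMs%.3f ms")
    file.delete()
  }
}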
In one job I ran recently, the time to write shuffle data to the file was only 4ms for each task, but the time to open the file was about 100x as long (~400ms). When we add metrics about spilled data (#2504), we should ensure that the file open time is also included there. Author: Kay Ousterhout Closes #4550 from kayousterhout/SPARK-3570 and squashes the following commits: ea3a4ae [Kay Ousterhout] Added comment about excluded open time fdc5185 [Kay Ousterhout] Improved comment 42b7e43 [Kay Ousterhout] Fixed parens for nanotime 2423555 [Kay Ousterhout] [SPARK-3570] Include time to open files in shuffle write time. --- .../org/apache/spark/shuffle/FileShuffleBlockManager.scala | 4 ++++ .../org/apache/spark/shuffle/sort/SortShuffleWriter.scala | 3 +++ .../org/apache/spark/util/collection/ExternalSorter.scala | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 660df00bc32f5..d0178dfde6935 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -112,6 +112,7 @@ class FileShuffleBlockManager(conf: SparkConf) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null + val openStartTime = System.nanoTime val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => @@ -135,6 +136,9 @@ class FileShuffleBlockManager(conf: SparkConf) blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) } } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, so should be included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { if (consolidateShuffleFiles) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index fa2e617762f55..55ea0f17b156a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -63,6 +63,9 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter.insertAll(records) } + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). 
val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3262e670c2030..b962c101c91da 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -352,6 +352,7 @@ private[spark] class ExternalSorter[K, V, C]( // Create our file writers if we haven't done so yet if (partitionWriters == null) { curWriteMetrics = new ShuffleWriteMetrics() + val openStartTime = System.nanoTime partitionWriters = Array.fill(numPartitions) { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -359,6 +360,10 @@ private[spark] class ExternalSorter[K, V, C]( val (blockId, file) = diskBlockManager.createTempShuffleBlock() blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) } // No need to sort stuff, just write each element out From f7c3668ee62d5b87a0fdf64614aa45c7409ff585 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 24 Mar 2015 16:41:31 -0700 Subject: [PATCH 020/171] Revert "[SPARK-5771][UI][hotfix] Change Requested Cores into * if default cores is not set" This reverts commit 12135e90549f957962899487cd5eb95badd8976d. --- .../scala/org/apache/spark/deploy/master/ui/MasterPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 46509e39c0f23..9ab3c8f8fed9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -218,7 +218,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } } - {if (app.requestedCores == Int.MaxValue) "*" else app.requestedCores} + {app.requestedCores} {Utils.megabytesToString(app.desc.memoryPerSlave)} From dd907d1a9df52ffe0a8e1e8dacd837019d11742c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 24 Mar 2015 16:49:27 -0700 Subject: [PATCH 021/171] Revert "[SPARK-5771] Number of Cores in Completed Applications of Standalone Master Web Page always be 0 if sc.stop() is called" This reverts commit dd077abf2e2949fdfec31074b760b587f00efcf2. 
Conflicts: core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala --- .../spark/deploy/master/ApplicationInfo.scala | 2 +- .../spark/deploy/master/ui/MasterPage.scala | 32 ++++--------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 536aedb6f9fe9..f979ffa16641a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -91,7 +91,7 @@ private[deploy] class ApplicationInfo( } } - private[master] val requestedCores = desc.maxCores.getOrElse(defaultCores) + private val requestedCores = desc.maxCores.getOrElse(defaultCores) private[master] def coresLeft: Int = requestedCores - coresGranted diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 9ab3c8f8fed9c..45412a35e9a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -75,16 +75,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workers = state.workers.sortBy(_.id) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use", - "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration") + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse - val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps) - - val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node", - "Submitted Time", "User", "State", "Duration") + val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) val completedApps = state.completedApps.sortBy(_.endTime).reverse - val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow, - completedApps) + val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", "Memory", "Main Class") @@ -191,7 +187,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = { + private def appRow(app: ApplicationInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { val killLinkUri = s"app/kill?id=${app.id}&terminate=true" @@ -201,7 +197,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { (kill) } - {app.id} @@ -210,15 +205,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.desc.name} - { - if (active) { - - {app.coresGranted} - - } - } - {app.requestedCores} + {app.coresGranted} {Utils.megabytesToString(app.desc.memoryPerSlave)} @@ -230,14 +218,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def activeAppRow(app: ApplicationInfo): Seq[Node] = { - appRow(app, active = true) - } - - private def 
completeAppRow(app: ApplicationInfo): Seq[Node] = { - appRow(app, active = false) - } - private def driverRow(driver: DriverInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (driver.state == DriverState.RUNNING || From 05c2214b41f4c0fd17b6f0c62e26398b963efe64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christophe=20Pr=C3=A9aud?= Date: Tue, 24 Mar 2015 17:05:49 -0700 Subject: [PATCH 022/171] [SPARK-6469] Improving documentation on YARN local directories usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Clarify the local directories usage in YARN Author: Christophe Préaud Closes #5165 from preaudc/yarn-doc-local-dirs and squashes the following commits: 6912b90 [Christophe Préaud] Fix some formatting issues. 4fa8ec2 [Christophe Préaud] Merge remote-tracking branch 'upstream/master' into yarn-doc-local-dirs eaaf519 [Christophe Préaud] Clarify the local directories usage in YARN 436fb7d [Christophe Préaud] Revert "Clarify the local directories usage in YARN" 876ae5e [Christophe Préaud] Clarify the local directories usage in YARN 608dbfa [Christophe Préaud] Merge remote-tracking branch 'upstream/master' a49a2ce [Christophe Préaud] Merge remote-tracking branch 'upstream/master' 9ba89ca [Christophe Préaud] Ensure that files are fetched atomically 54419ae [Christophe Préaud] Merge remote-tracking branch 'upstream/master' c6a5590 [Christophe Préaud] Revert commit 8ea871f8130b2490f1bad7374a819bf56f0ccbbd 7456a33 [Christophe Préaud] Merge remote-tracking branch 'upstream/master' 8ea871f [Christophe Préaud] Ensure that files are fetched atomically --- docs/running-on-yarn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 68b1aeb8ebd01..d9f3eb2b74b18 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -274,6 +274,6 @@ If you need a reference to the proper location to put log files in the YARN so t # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. +- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. 
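For readers configuring this, a small hedged sketch of how the setting interacts with the two modes described above (the paths are invented; the comments simply restate the documentation note):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("yarn-local-dirs-example")
      // yarn-cluster mode: this value is ignored; both the driver and the executors
      // use the directories from YARN's yarn.nodemanager.local-dirs.
      // yarn-client mode: only the driver (which runs outside the YARN cluster)
      // uses this value; the executors still use the YARN-configured directories.
      .set("spark.local.dir", "/mnt/disk1/spark-tmp,/mnt/disk2/spark-tmp")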
From 6930e965e26d39fa6c26ae67a08b4c4d0368d556 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 24 Mar 2015 17:06:22 -0700 Subject: [PATCH 023/171] [SPARK-6512] add contains to OpenHashMap Add `contains` to test whether a key exists in an OpenHashMap. rxin Author: Xiangrui Meng Closes #5171 from mengxr/openhashmap-contains and squashes the following commits: d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap 748a69b [Xiangrui Meng] add contains to OpenHashMap --- .../org/apache/spark/util/collection/OpenHashMap.scala | 9 +++++++++ .../util/collection/PrimitiveKeyOpenHashMap.scala | 5 +++++ .../spark/util/collection/OpenHashMapSuite.scala | 10 ++++++++++ .../util/collection/PrimitiveKeyOpenHashMapSuite.scala | 7 +++++++ 4 files changed, 31 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index c52591b352340..efc2482c74ddf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -53,6 +53,15 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + /** Tests whether this map contains a binding for a key. */ + def contains(k: K): Boolean = { + if (k == null) { + haveNullValue + } else { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } + } + /** Get the value for a given key */ def apply(k: K): V = { if (k == null) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index 61e22642761f0..b4ec4ea521253 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -48,6 +48,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, override def size: Int = _keySet.size + /** Tests whether this map contains a binding for a key. 
*/ + def contains(k: K): Boolean = { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } + /** Get the value for a given key */ def apply(k: K): V = { val pos = _keySet.getPos(k) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 6a70877356409..ef890d2ba60f3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -176,4 +176,14 @@ class OpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toString) === i.toString) } } + + test("contains") { + val map = new OpenHashMap[String, Int](2) + map("a") = 1 + assert(map.contains("a")) + assert(!map.contains("b")) + assert(!map.contains(null)) + map(null) = 0 + assert(map.contains(null)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 8c7df7d73dcd3..caf378fec8b3e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -118,4 +118,11 @@ class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toLong) === i.toString) } } + + test("contains") { + val map = new PrimitiveKeyOpenHashMap[Int, Int](1) + map(0) = 0 + assert(map.contains(0)) + assert(!map.contains(1)) + } } From 94598653bc772e71709163db3fed4048aa7f5f75 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Mar 2015 17:08:25 -0700 Subject: [PATCH 024/171] [SPARK-6428][Streaming] Added explicit types for all public methods. Author: Reynold Xin Closes #5110 from rxin/streaming-explicit-type and squashes the following commits: 2c2db32 [Reynold Xin] [SPARK-6428][Streaming] Added explicit types for all public methods. 
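The convention enforced by this patch is easiest to see on a tiny, self-contained example (SimpleDuration is an invented stand-in for the streaming Duration class touched below):

    class SimpleDuration(val millis: Long)

    object DurationsSketch {
      // Inferred return type: the public signature is whatever the body happens to
      // produce, so an innocuous refactoring can silently change the API.
      // def seconds(n: Long) = new SimpleDuration(n * 1000)

      // Explicit return type: the style this patch applies to every public method.
      def seconds(n: Long): SimpleDuration = new SimpleDuration(n * 1000)
    }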
--- .../apache/spark/streaming/Checkpoint.scala | 6 ++-- .../apache/spark/streaming/DStreamGraph.scala | 6 ++-- .../org/apache/spark/streaming/Duration.scala | 15 ++++----- .../org/apache/spark/streaming/Interval.scala | 10 +++--- .../spark/streaming/StreamingContext.scala | 5 +-- .../streaming/api/java/JavaDStreamLike.scala | 3 +- .../streaming/api/python/PythonDStream.scala | 15 +++++---- .../spark/streaming/dstream/DStream.scala | 20 ++++++----- .../dstream/DStreamCheckpointData.scala | 2 +- .../streaming/dstream/FileInputDStream.scala | 4 +-- .../streaming/dstream/FilteredDStream.scala | 2 +- .../dstream/FlatMapValuedDStream.scala | 2 +- .../streaming/dstream/FlatMappedDStream.scala | 2 +- .../streaming/dstream/ForEachDStream.scala | 2 +- .../streaming/dstream/GlommedDStream.scala | 2 +- .../streaming/dstream/InputDStream.scala | 2 +- .../dstream/MapPartitionedDStream.scala | 2 +- .../streaming/dstream/MapValuedDStream.scala | 2 +- .../streaming/dstream/MappedDStream.scala | 2 +- .../dstream/ReducedWindowedDStream.scala | 4 +-- .../streaming/dstream/ShuffledDStream.scala | 2 +- .../streaming/dstream/StateDStream.scala | 2 +- .../dstream/TransformedDStream.scala | 2 +- .../streaming/dstream/UnionDStream.scala | 6 ++-- .../streaming/dstream/WindowedDStream.scala | 2 +- .../rdd/WriteAheadLogBackedBlockRDD.scala | 6 ++-- .../streaming/receiver/ActorReceiver.scala | 18 +++++----- .../streaming/receiver/BlockGenerator.scala | 2 +- .../spark/streaming/receiver/Receiver.scala | 2 +- .../receiver/ReceiverSupervisor.scala | 4 +-- .../receiver/ReceiverSupervisorImpl.scala | 6 ++-- .../spark/streaming/scheduler/Job.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 8 ++--- .../streaming/scheduler/JobScheduler.scala | 2 +- .../spark/streaming/scheduler/JobSet.scala | 11 +++---- .../streaming/scheduler/ReceiverTracker.scala | 11 +++---- .../ui/StreamingJobProgressListener.scala | 33 +++++++++++-------- .../spark/streaming/util/RawTextHelper.scala | 14 ++++---- 38 files changed, 127 insertions(+), 114 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index db64e11e16304..f73b463d07779 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -67,12 +67,12 @@ object Checkpoint extends Logging { val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r /** Get the checkpoint file for the given checkpoint time */ - def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) } /** Get the checkpoint backup file for the given checkpoint time */ - def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") } @@ -232,6 +232,8 @@ object CheckpointReader extends Logging { def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) + + // TODO(rxin): Why is this a def?! 
def fs = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 0e285d6088ec1..175140481e5ae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -100,11 +100,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - def getInputStreams() = this.synchronized { inputStreams.toArray } + def getInputStreams(): Array[InputDStream[_]] = this.synchronized { inputStreams.toArray } - def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getOutputStreams(): Array[DStream[_]] = this.synchronized { outputStreams.toArray } - def getReceiverInputStreams() = this.synchronized { + def getReceiverInputStreams(): Array[ReceiverInputDStream[_]] = this.synchronized { inputStreams.filter(_.isInstanceOf[ReceiverInputDStream[_]]) .map(_.asInstanceOf[ReceiverInputDStream[_]]) .toArray diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index a0d8fb5ab93ec..3249bb348981f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -55,7 +55,6 @@ case class Duration (private val millis: Long) { def div(that: Duration): Double = this / that - def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) @@ -71,7 +70,7 @@ case class Duration (private val millis: Long) { def milliseconds: Long = millis - def prettyPrint = Utils.msDurationToString(millis) + def prettyPrint: String = Utils.msDurationToString(millis) } @@ -80,7 +79,7 @@ case class Duration (private val millis: Long) { * a given number of milliseconds. */ object Milliseconds { - def apply(milliseconds: Long) = new Duration(milliseconds) + def apply(milliseconds: Long): Duration = new Duration(milliseconds) } /** @@ -88,7 +87,7 @@ object Milliseconds { * a given number of seconds. */ object Seconds { - def apply(seconds: Long) = new Duration(seconds * 1000) + def apply(seconds: Long): Duration = new Duration(seconds * 1000) } /** @@ -96,7 +95,7 @@ object Seconds { * a given number of minutes. */ object Minutes { - def apply(minutes: Long) = new Duration(minutes * 60000) + def apply(minutes: Long): Duration = new Duration(minutes * 60000) } // Java-friendlier versions of the objects above. @@ -107,16 +106,16 @@ object Durations { /** * @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds. */ - def milliseconds(milliseconds: Long) = Milliseconds(milliseconds) + def milliseconds(milliseconds: Long): Duration = Milliseconds(milliseconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of seconds. */ - def seconds(seconds: Long) = Seconds(seconds) + def seconds(seconds: Long): Duration = Seconds(seconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of minutes. 
*/ - def minutes(minutes: Long) = Minutes(minutes) + def minutes(minutes: Long): Duration = Minutes(minutes) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala index ad4f3fdd14ad6..3f5be785e1b1a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -39,18 +39,18 @@ class Interval(val beginTime: Time, val endTime: Time) { this.endTime < that.endTime } - def <= (that: Interval) = (this < that || this == that) + def <= (that: Interval): Boolean = (this < that || this == that) - def > (that: Interval) = !(this <= that) + def > (that: Interval): Boolean = !(this <= that) - def >= (that: Interval) = !(this < that) + def >= (that: Interval): Boolean = !(this < that) - override def toString = "[" + beginTime + ", " + endTime + "]" + override def toString: String = "[" + beginTime + ", " + endTime + "]" } private[streaming] object Interval { - def currentInterval(duration: Duration): Interval = { + def currentInterval(duration: Duration): Interval = { val time = new Time(System.currentTimeMillis) val intervalBegin = time.floor(duration) new Interval(intervalBegin, intervalBegin + duration) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 543224d4b07bc..f57f295874645 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -188,7 +188,7 @@ class StreamingContext private[streaming] ( /** * Return the associated Spark context */ - def sparkContext = sc + def sparkContext: SparkContext = sc /** * Set each DStreams in this context to remember RDDs it generated in the last given duration. @@ -596,7 +596,8 @@ object StreamingContext extends Logging { @deprecated("Replaced by implicit functions in the DStream companion object. 
This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) + : PairDStreamFunctions[K, V] = { DStream.toPairDStreamFunctions(stream)(kt, vt, ord) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 2eabdd9387913..73030e15c5661 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -415,8 +415,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T implicit val cmv2: ClassTag[V2] = fakeClassTag implicit val cmw: ClassTag[W] = fakeClassTag - def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = + def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = { transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd + } dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 7053f47ec69a2..4c28654ef6413 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -176,11 +176,11 @@ private[python] abstract class PythonDStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -212,7 +212,7 @@ private[python] class PythonTransformed2DStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent, parent2) + override def dependencies: List[DStream[_]] = List(parent, parent2) override def slideDuration: Duration = parent.slideDuration @@ -223,7 +223,7 @@ private[python] class PythonTransformed2DStream( func(Some(rdd1), Some(rdd2), validTime) } - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -260,12 +260,15 @@ private[python] class PythonReducedWindowedDStream( extends PythonDStream(parent, preduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) - override val mustCheckpoint = true - val invReduceFunc = new TransformFunction(pinvReduceFunc) + override val mustCheckpoint: Boolean = true + + val invReduceFunc: TransformFunction = new TransformFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index b874f561c12eb..795c5aa6d585b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -104,7 +104,7 @@ abstract class DStream[T: ClassTag] ( private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ - def context = ssc + def context: StreamingContext = ssc /* Set the creation call site */ private[streaming] val creationSite = DStream.getCreationSite() @@ -619,14 +619,16 @@ abstract class DStream[T: ClassTag] ( * operator, so this DStream will be registered as an output stream and there materialized. */ def print(num: Int) { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val firstNum = rdd.take(num + 1) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - firstNum.take(num).foreach(println) - if (firstNum.size > num) println("...") - println() + def foreachFunc: (RDD[T], Time) => Unit = { + (rdd: RDD[T], time: Time) => { + val firstNum = rdd.take(num + 1) + println("-------------------------------------------") + println("Time: " + time) + println("-------------------------------------------") + firstNum.take(num).foreach(println) + if (firstNum.size > num) println("...") + println() + } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 0dc72790fbdbd..39fd21342813e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -114,7 +114,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) } } - override def toString() = { + override def toString: String = { "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 22de8c02e63c8..66d519171fd76 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -298,7 +298,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] + private def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] override def update(time: Time) { hadoopFiles.clear() @@ -320,7 +320,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( } } - override def toString() = { + override def toString: String = { "[\n" + hadoopFiles.size + " file sets\n" + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index c81534ae584ea..fcd5216f101af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -27,7 +27,7 @@ class FilteredDStream[T: ClassTag]( filterFunc: T => 
Boolean ) extends DStream[T](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 658623455498c..9d09a3baf37ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -28,7 +28,7 @@ class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( flatMapValueFunc: V => TraversableOnce[U] ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index c7bb2833eabb8..475ea2d2d4f38 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -27,7 +27,7 @@ class FlatMappedDStream[T: ClassTag, U: ClassTag]( flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 1361c30395b57..685a32e1d280d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -28,7 +28,7 @@ class ForEachDStream[T: ClassTag] ( foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index a9bb51f054048..dbb295fe54f71 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -25,7 +25,7 @@ private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index aa1993f0580a8..e652702e213ef 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -61,7 +61,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) } } - override def dependencies = 
List() + override def dependencies: List[DStream[_]] = List() override def slideDuration: Duration = { if (ssc == null) throw new Exception("ssc is null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 3d8ee29df1e82..5994bc1e23f2b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -28,7 +28,7 @@ class MapPartitionedDStream[T: ClassTag, U: ClassTag]( preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 7aea1f945d9db..954d2eb4a7b00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -28,7 +28,7 @@ class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( mapValueFunc: V => U ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index 02704a8d1c2e0..fa14b2e897c3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -27,7 +27,7 @@ class MappedDStream[T: ClassTag, U: ClassTag] ( mapFunc: T => U ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index c0a5af0b65cc3..1385ccbf56ee5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -52,7 +52,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Reduce each batch of data using reduceByKey which will be further reduced by window // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + private val reducedStream = parent.reduceByKey(reduceFunc, partitioner) // Persist RDDs to memory by default as these RDDs are going to be reused. 
super.persist(StorageLevel.MEMORY_ONLY_SER) @@ -60,7 +60,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(reducedStream) + override def dependencies: List[DStream[_]] = List(reducedStream) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 880a89bc36895..7757ccac09a58 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -33,7 +33,7 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( mapSideCombine: Boolean = true ) extends DStream[(K,C)] (parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index ebb04dd35b9a2..de8718d0a80fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -36,7 +36,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 71b61856e23c0..5d46ca0715ffd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -32,7 +32,7 @@ class TransformedDStream[U: ClassTag] ( require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index abbc40befa95b..9405dbaa12329 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -33,17 +33,17 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { + parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + 
validTime) - }) + } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 775b6bfd065c0..899865a906c27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -46,7 +46,7 @@ class WindowedDStream[T: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index dd1e96334952f..93caa4ba35c7f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -117,8 +117,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPreferredLocations(split: Partition): Seq[String] = { val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockLocations = getBlockIdLocations().get(partition.blockId) - def segmentLocations = HdfsUtils.getFileSegmentLocations( - partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig) - blockLocations.getOrElse(segmentLocations) + blockLocations.getOrElse( + HdfsUtils.getFileSegmentLocations( + partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index a7d63bd4f2dbf..cd309788a7717 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.receiver +import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration._ @@ -25,10 +26,10 @@ import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} + import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.StorageLevel -import java.nio.ByteBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -149,13 +150,13 @@ private[streaming] class ActorReceiver[T: ClassTag]( class Supervisor extends Actor { override val supervisorStrategy = receiverSupervisorStrategy - val worker = context.actorOf(props, name) + private val worker = context.actorOf(props, name) logInfo("Started receiver worker at:" + worker.path) - val n: AtomicInteger = new AtomicInteger(0) - val hiccups: AtomicInteger = new AtomicInteger(0) + private val n: AtomicInteger = new AtomicInteger(0) + private val hiccups: AtomicInteger = new AtomicInteger(0) - def receive = { + override def receive: PartialFunction[Any, Unit] = { case IteratorData(iterator) => logDebug("received iterator") @@ -189,13 +190,12 @@ private[streaming] class ActorReceiver[T: ClassTag]( } } - def onStart() = { + def onStart(): Unit = { supervisor 
logInfo("Supervision tree for receivers initialized at:" + supervisor.path) - } - def onStop() = { + def onStop(): Unit = { supervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index ee5e639b26d91..42514d8b47dcf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -120,7 +120,7 @@ private[streaming] class BlockGenerator( * `BlockGeneratorListener.onAddData` callback will be called. All received data items * will be periodically pushed into BlockManager. */ - def addDataWithCallback(data: Any, metadata: Any) = synchronized { + def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { waitToPush() currentBuffer += data listener.onAddData(data, metadata) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5acf8a9a811ee..5b5a3fe648602 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -245,7 +245,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * Get the unique identifier the receiver input stream that this * receiver is associated with. */ - def streamId = id + def streamId: Int = id /* * ================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 1f0244c251eba..4943f29395d12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -162,13 +162,13 @@ private[streaming] abstract class ReceiverSupervisor( } /** Check if receiver has been marked for stopping */ - def isReceiverStarted() = { + def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started } /** Check if receiver has been marked for stopping */ - def isReceiverStopped() = { + def isReceiverStopped(): Boolean = { logDebug("state = " + receiverState) receiverState == Stopped } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 7d29ed88cfcb4..8f2f1fef76874 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer import scala.concurrent.Await -import akka.actor.{Actor, Props} +import akka.actor.{ActorRef, Actor, Props} import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration @@ -83,7 +83,7 @@ private[streaming] class ReceiverSupervisorImpl( private val actor = env.actorSystem.actorOf( Props(new Actor { - override def receive() = { + override def receive: PartialFunction[Any, Unit] = { case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) @@ -92,7 +92,7 @@ 
private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) } - def ref = self + def ref: ActorRef = self }), "Receiver-" + streamId + "-" + System.currentTimeMillis()) /** Unique block ids if one wants to add blocks directly */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 7e0f6b2cdfc08..30cf87f5b7dd1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -36,5 +36,5 @@ class Job(val time: Time, func: () => _) { id = "streaming job " + time + "." + number } - override def toString = id + override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 59488dfb0f8c6..4946806d2ee95 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -82,7 +82,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { if (eventActor != null) return // generator has already been started eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case event: JobGeneratorEvent => processEvent(event) } }), "JobGenerator") @@ -111,8 +111,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val pollTime = 100 // To prevent graceful stop to get stuck permanently - def hasTimedOut = { - val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + def hasTimedOut: Boolean = { + val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeout if (timedOut) { logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")") } @@ -133,7 +133,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Stopped generation timer") // Wait for the jobs to complete and checkpoints to be written - def haveAllBatchesBeenProcessed = { + def haveAllBatchesBeenProcessed: Boolean = { lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime } logInfo("Waiting for jobs to be processed and checkpoints to be written") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 60bc099b27a4c..d6a93acbe711b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -56,7 +56,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { logDebug("Starting JobScheduler") eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case event: JobSchedulerEvent => processEvent(event) } }), "JobScheduler") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 8c15a75b1b0e0..5b134877d0b2d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,8 +28,7 @@ private[streaming] 
case class JobSet( time: Time, jobs: Seq[Job], - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty - ) { + receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -48,17 +47,17 @@ case class JobSet( if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted = processingStartTime > 0 + def hasStarted: Boolean = processingStartTime > 0 - def hasCompleted = incompleteJobs.isEmpty + def hasCompleted: Boolean = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. not including the time they wait in the streaming scheduler queue) - def processingDelay = processingEndTime - processingStartTime + def processingDelay: Long = processingEndTime - processingStartTime // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay = { + def totalDelay: Long = { processingEndTime - time.milliseconds } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b36aeb341d25e..98900473138fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -72,7 +72,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ - def start() = synchronized { + def start(): Unit = synchronized { if (actor != null) { throw new SparkException("ReceiverTracker already started") } @@ -86,7 +86,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Stop the receiver execution thread. */ - def stop(graceful: Boolean) = synchronized { + def stop(graceful: Boolean): Unit = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers if (!skipReceiverLaunch) receiverExecutor.stop(graceful) @@ -201,7 +201,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Actor to receive messages from the receivers. */ private class ReceiverTrackerActor extends Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverActor) => registerReceiver(streamId, typ, host, receiverActor, sender) sender ! 
true @@ -244,16 +244,15 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (graceful) { val pollTime = 100 - def done = { receiverInfo.isEmpty && !running } logInfo("Waiting for receiver job to terminate gracefully") - while(!done) { + while (receiverInfo.nonEmpty || running) { Thread.sleep(pollTime) } logInfo("Waited for receiver job to terminate gracefully") } // Check if all the receivers have been deregistered or not - if (!receiverInfo.isEmpty) { + if (receiverInfo.nonEmpty) { logWarning("Not all of the receivers have deregistered, " + receiverInfo) } else { logInfo("All of the receivers have deregistered successfully") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 5ee53a5c5f561..e4bd067cacb77 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.ui +import scala.collection.mutable.{Queue, HashMap} + import org.apache.spark.streaming.{Time, StreamingContext} import org.apache.spark.streaming.scheduler._ -import scala.collection.mutable.{Queue, HashMap} import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted import org.apache.spark.streaming.scheduler.BatchInfo @@ -59,11 +60,13 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + synchronized { + runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + } } - override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized { + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) @@ -72,19 +75,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized { - waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) - runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() - totalCompletedBatches += 1L - - batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalProcessedRecords += infos.map(_.numRecords).sum + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + synchronized { + waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) + runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) + completedaBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() + totalCompletedBatches += 1L + + batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => + totalProcessedRecords += infos.map(_.numRecords).sum + } } } - def 
numReceivers = synchronized { + def numReceivers: Int = synchronized { ssc.graph.getReceiverInputStreams().size } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index a73d6f3bf0661..4d968f8bfa7a8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.util.collection.OpenHashMap -import scala.collection.JavaConversions.mapAsScalaMap private[streaming] object RawTextHelper { @@ -71,7 +69,7 @@ object RawTextHelper { var count = 0 while(data.hasNext) { - value = data.next + value = data.next() if (value != null) { count += 1 if (len == 0) { @@ -108,9 +106,13 @@ object RawTextHelper { } } - def add(v1: Long, v2: Long) = (v1 + v2) + def add(v1: Long, v2: Long): Long = { + v1 + v2 + } - def subtract(v1: Long, v2: Long) = (v1 - v2) + def subtract(v1: Long, v2: Long): Long = { + v1 - v2 + } - def max(v1: Long, v2: Long) = math.max(v1, v2) + def max(v1: Long, v2: Long): Long = math.max(v1, v2) } From c14ddd97ed662a8429b9b9078bd7c1a5a1dd3d6c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 24 Mar 2015 18:58:27 -0700 Subject: [PATCH 025/171] [SPARK-6515] update OpenHashSet impl Though I don't see any bug in the existing code, the update in this PR makes it read better. rxin Author: Xiangrui Meng Closes #5176 from mengxr/SPARK-6515 and squashes the following commits: 134494d [Xiangrui Meng] update OpenHashSet impl --- .../spark/util/collection/OpenHashSet.scala | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index c80057f95e0b2..1501111a06655 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -122,7 +122,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def addWithoutResize(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 + var delta = 1 while (true) { if (!_bitset.get(pos)) { // This is a new key. @@ -134,14 +134,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( // Found an existing key. return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** @@ -163,21 +161,19 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def getPos(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 - val maxProbe = _data.size - while (i < maxProbe) { + var delta = 1 + while (true) { if (!_bitset.get(pos)) { return INVALID_POS } else if (k == _data(pos)) { return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** Return the value at the specified position. 
*/ From c5cc41468e8709d09c09289bb55bc8edc99404b1 Mon Sep 17 00:00:00 2001 From: Bill Chambers Date: Tue, 24 Mar 2015 22:24:35 -0700 Subject: [PATCH 026/171] [DOCUMENTATION]Fixed Missing Type Import in Documentation Needed to import the types specifically, not the more general pyspark.sql Author: Bill Chambers Author: anabranch Closes #5179 from anabranch/master and squashes the following commits: 8fa67bf [anabranch] Corrected SqlContext Import 603b080 [Bill Chambers] [DOCUMENTATION]Fixed Missing Type Import in Documentation --- docs/sql-programming-guide.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6a333fdb562a7..c99a0b03442c4 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -624,7 +624,8 @@ tuples or lists in the RDD created in the step 1. For example: {% highlight python %} # Import SQLContext and data types -from pyspark.sql import * +from pyspark.sql import SQLContext +from pyspark.sql.types import * # sc is an existing SparkContext. sqlContext = SQLContext(sc) From 64262ed99912e780b51f240a14dc98fc3cdf916d Mon Sep 17 00:00:00 2001 From: zzcclp Date: Wed, 25 Mar 2015 19:11:04 +0800 Subject: [PATCH 027/171] [SPARK-6483][SQL]Improve ScalaUdf called performance. As described in [SPARK-6483](https://issues.apache.org/jira/browse/SPARK-6483), ScalaUdf performs poorly because it calls *asInstanceOf* to convert every record. With this change, the performance of ScalaUdf matches the other expression cases. Thanks to lianhuiwang for explaining how to resolve this problem. Author: zzcclp Closes #5154 from zzcclp/SPARK-6483 and squashes the following commits: 5ac6e09 [zzcclp] Add a newline at the end of source file cc6868e [zzcclp] Fix for fail on unit test. 0a8cdc3 [zzcclp] indention issue b73836a [zzcclp] Access Seq[Expression] element by :: operator, and update the code gen script.
7763848 [zzcclp] rebase from master --- .../sql/catalyst/expressions/ScalaUdf.scala | 1016 +++++++++++------ 1 file changed, 661 insertions(+), 355 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 1fd5ce342b2ce..389dc4f745723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -39,363 +39,669 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) - - s""" - case $x => - function.asInstanceOf[($anys) => Any]( - $evals) - """ + val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) + val evals = (0 to x - 1).map(x => s"ScalaReflection.convertToScala(child$x.eval(input), child$x.dataType)").reduce(_ + ",\n " + _) + + s""" case $x => + val func = function.asInstanceOf[($anys) => Any] + $childs + (input: Row) => { + func( + $evals) + } + """ }.foreach(println) */ - - override def eval(input: Row): Any = { - val result = children.size match { - case 0 => function.asInstanceOf[() => Any]() - case 1 => - function.asInstanceOf[(Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) - - - case 2 => - function.asInstanceOf[(Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) - - - case 3 => - function.asInstanceOf[(Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) - - - case 4 => - function.asInstanceOf[(Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) - - - case 5 => - function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) - - - case 6 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), 
children(5).dataType)) - - - case 7 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) - - - case 8 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) - - - case 9 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) - - - case 10 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) - - - case 11 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), 
children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) - - - case 12 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) - - - case 13 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) - - - case 14 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - 
ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) - - - case 15 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) - - - case 16 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) - - - case 17 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - 
ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) - - - case 18 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) - - - case 19 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), 
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) - - - case 20 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) - - - case 21 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - 
ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) - - - case 22 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), - ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) - - } - // scalastyle:on - - ScalaReflection.convertToCatalyst(result, dataType) + + val f = children.size match { + case 0 => + val func = function.asInstanceOf[() => Any] + (input: Row) => { + func() + } + + case 1 => + val func = function.asInstanceOf[(Any) => Any] + 
val child0 = children(0) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType)) + } + + case 2 => + val func = function.asInstanceOf[(Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType)) + } + + case 3 => + val func = function.asInstanceOf[(Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType)) + } + + case 4 => + val func = function.asInstanceOf[(Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType)) + } + + case 5 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType)) + } + + case 6 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType)) + } + + case 7 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType)) + } + + case 8 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val 
child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType)) + } + + case 9 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType)) + } + + case 10 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType)) + } + + case 11 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + 
ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType)) + } + + case 12 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType)) + } + + case 13 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType)) + } + + case 14 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val 
child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType)) + } + + case 15 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType)) + } + + case 16 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = 
children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType)) + } + + case 17 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType)) + } + + case 18 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 
= children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType)) + } + + case 19 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + 
ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType)) + } + + case 20 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType)) + } + + case 21 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + 
ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType), + ScalaReflection.convertToScala(child20.eval(input), child20.dataType)) + } + + case 22 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + val child21 = children(21) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType), + ScalaReflection.convertToScala(child20.eval(input), child20.dataType), + ScalaReflection.convertToScala(child21.eval(input), child21.dataType)) + } } + + // scalastyle:on + + override def eval(input: Row): 
Any = ScalaReflection.convertToCatalyst(f(input), dataType) + } From 10c78607b2724f5a64b0cdb966e9c5805f23919b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 25 Mar 2015 17:05:56 +0000 Subject: [PATCH 028/171] [SPARK-6496] [MLLIB] GeneralizedLinearAlgorithm.run(input, initialWeights) should initialize numFeatures In GeneralizedLinearAlgorithm ```numFeatures``` is default to -1, we need to update it to correct value when we call run() to train a model. ```LogisticRegressionWithLBFGS.run(input)``` works well, but when we call ```LogisticRegressionWithLBFGS.run(input, initialWeights)``` to train multiclass classification model, it will throw exception due to the numFeatures is not updated. In this PR, we just update numFeatures at the beginning of GeneralizedLinearAlgorithm.run(input, initialWeights) and add test case. Author: Yanbo Liang Closes #5167 from yanboliang/spark-6496 and squashes the following commits: 8131c48 [Yanbo Liang] LogisticRegressionWithLBFGS.run(input, initialWeights) should initialize numFeatures --- .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 4 ++++ .../mllib/classification/LogisticRegressionSuite.scala | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 45b9ebb4cc0d6..9fd60ff7a0c79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -211,6 +211,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index aaa81da9e273c..a26c52852c4d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -425,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val model = lr.run(testRDD) + val numFeatures = testRDD.map(_.features.size).first() + val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2)) + val model2 = lr.run(testRDD, initialWeights) + + LogisticRegressionSuite.checkModelsEqual(model, model2) + /** * The following is the instruction to reproduce the model using R's glmnet package. * From 982952f4aebb474823dd886dd2b18f4277bd7c30 Mon Sep 17 00:00:00 2001 From: Augustin Borsu Date: Wed, 25 Mar 2015 10:16:39 -0700 Subject: [PATCH 029/171] [ML][FEATURE] SPARK-5566: RegEx Tokenizer Added a Regex based tokenizer for ml. Currently the regex is fixed but if I could add a regex type paramater to the paramMap, changing the tokenizer regex could be a parameter used in the crossValidation. Also I wonder what would be the best way to add a stop word list. 
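As a rough, standalone illustration of the regex-matching versus gap-splitting behavior described above (hypothetical helper, not the patch's API — the actual `RegexTokenizer` added by this PR appears in the diff below):

```scala
// Sketch only: extract tokens by repeatedly matching the regex (gaps = false)
// or by splitting the text on it (gaps = true), then drop tokens that are too short.
def tokenize(text: String, pattern: String, gaps: Boolean, minTokenLength: Int): Seq[String] = {
  val re = pattern.r
  val tokens = if (gaps) re.split(text).toSeq else re.findAllIn(text).toSeq
  tokens.filter(_.length >= minTokenLength)
}

tokenize("Te,st. punct", "\\p{L}+|[^\\p{L}\\s]+", gaps = false, minTokenLength = 1)
// Seq("Te", ",", "st", ".", "punct")
```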
Author: Augustin Borsu Author: Augustin Borsu Author: Augustin Borsu Author: Xiangrui Meng Closes #4504 from aborsu985/master and squashes the following commits: 716d257 [Augustin Borsu] Merge branch 'mengxr-SPARK-5566' cb07021 [Augustin Borsu] Merge branch 'SPARK-5566' of git://github.com/mengxr/spark into mengxr-SPARK-5566 5f09434 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' a164800 [Xiangrui Meng] remove tabs 556aa27 [Xiangrui Meng] Merge branch 'aborsu985-master' into SPARK-5566 9651aec [Xiangrui Meng] update test f96526d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5566 2338da5 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' e88d7b8 [Xiangrui Meng] change pattern to a StringParameter; update tests 148126f [Augustin Borsu] Added return type to public functions 12dddb4 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' daf685e [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 6a85982 [Augustin Borsu] Style corrections 38b95a1 [Augustin Borsu] Added Java unit test for RegexTokenizer b66313f [Augustin Borsu] Modified the pattern Param so it is compiled when given to the Tokenizer e262bac [Augustin Borsu] Added unit tests in scala cd6642e [Augustin Borsu] Changed regex to pattern 132b00b [Augustin Borsu] Changed matching to gaps and removed case folding 201a107 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' cb9c9a7 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' d3ef6d3 [Augustin Borsu] Added doc to RegexTokenizer 9082fc3 [Augustin Borsu] Removed stopwords parameters and updated doc 19f9e53 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' f6a5002 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 7f930bb [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 77ff9ca [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 2e89719 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 196cd7a [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 11ca50f [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 9f8685a [Augustin Borsu] RegexTokenizer 9e07a78 [Augustin Borsu] Merge remote-tracking branch 'upstream/master' 9547e9d [Augustin Borsu] RegEx Tokenizer 01cd26f [Augustin Borsu] RegExTokenizer --- .../apache/spark/ml/feature/Tokenizer.scala | 66 +++++++++++++- .../spark/ml/feature/JavaTokenizerSuite.java | 71 ++++++++++++++++ .../spark/ml/feature/TokenizerSuite.scala | 85 +++++++++++++++++++ 3 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0b1f90daa7d8e..68401e36950bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param} import org.apache.spark.sql.types.{DataType, StringType, ArrayType} /** @@ -39,3 +39,67 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { override protected def 
outputDataType: DataType = new ArrayType(StringType, false) } + +/** + * :: AlphaComponent :: + * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) + * or using it to split the text (set matching to false). Optional parameters also allow to fold + * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length. + * It returns an array of strings that can be empty. + * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true, + * lowercase = false, minTokenLength = 1 + */ +@AlphaComponent +class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + + /** + * param for minimum token length, default is one to avoid returning empty strings + * @group param + */ + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1)) + + /** @group setParam */ + def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) + + /** @group getParam */ + def getMinTokenLength: Int = get(minTokenLength) + + /** + * param sets regex as splitting on gaps (true) or matching tokens (false) + * @group param + */ + val gaps: BooleanParam = new BooleanParam( + this, "gaps", "Set regex to match gaps or tokens", Some(false)) + + /** @group setParam */ + def setGaps(value: Boolean): this.type = set(gaps, value) + + /** @group getParam */ + def getGaps: Boolean = get(gaps) + + /** + * param sets regex pattern used by tokenizer + * @group param + */ + val pattern: Param[String] = new Param( + this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+")) + + /** @group setParam */ + def setPattern(value: String): this.type = set(pattern, value) + + /** @group getParam */ + def getPattern: String = get(pattern) + + override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => + val re = paramMap(pattern).r + val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq + val minLength = paramMap(minTokenLength) + tokens.filter(_.length >= minLength) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java new file mode 100644 index 0000000000000..3806f650025b2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaTokenizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void regexTokenizer() { + RegexTokenizer myRegExTokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(3); + + JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), + new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + )); + DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + + Row[] pairs = myRegExTokenizer.transform(dataset) + .select("tokens", "wantedTokens") + .collect(); + + for (Row r : pairs) { + Assert.assertEquals(r.get(0), r.get(1)); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala new file mode 100644 index 0000000000000..bf862b912d326 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +@BeanInfo +case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) { + /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */ + def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq) +} + +class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("RegexTokenizer") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + + val dataset0 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) + )) + testRegexTokenizer(tokenizer, dataset0) + + val dataset1 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Seq("punct")) + )) + + tokenizer.setMinTokenLength(3) + testRegexTokenizer(tokenizer, dataset1) + + tokenizer + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(0) + val dataset2 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct")) + )) + testRegexTokenizer(tokenizer, dataset2) + } +} + +object RegexTokenizerSuite extends FunSuite { + + def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("tokens", "wantedTokens") + .collect() + .foreach { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } +} From 968408b345a0e26f7ee9105a6a0c3456cf10576a Mon Sep 17 00:00:00 2001 From: DoingDone9 <799203320@qq.com> Date: Wed, 25 Mar 2015 11:11:52 -0700 Subject: [PATCH 030/171] [SPARK-6409][SQL] It is not necessary that avoid old inteface of hive, because this will make some UDAF can not work. 
spark avoid old inteface of hive, then some udaf can not work like "org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage" Author: DoingDone9 <799203320@qq.com> Closes #5131 from DoingDone9/udaf and squashes the following commits: 9de08d0 [DoingDone9] Update HiveUdfSuite.scala 49c62dc [DoingDone9] Update hiveUdfs.scala 98b134f [DoingDone9] Merge pull request #5 from apache/master 161cae3 [DoingDone9] Merge pull request #4 from apache/master c87e8b6 [DoingDone9] Merge pull request #3 from apache/master cb1852d [DoingDone9] Merge pull request #2 from apache/master c3f046f [DoingDone9] Merge pull request #1 from apache/master --- .../scala/org/apache/spark/sql/hive/hiveUdfs.scala | 3 +-- .../spark/sql/hive/execution/HiveUdfSuite.scala | 11 ++++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bfe43373d9534..47305571e579e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -375,9 +375,8 @@ private[hive] case class HiveUdafFunction( private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - // Cast required to avoid type inference selecting a deprecated Hive API. private val buffer = - function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + function.getNewAggregationBuffer override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index cb405f56bf53d..d7c5d1a25a82b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -22,7 +22,7 @@ import java.util import java.util.Properties import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} @@ -93,6 +93,15 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + test("SPARK-6409 UDAFAverage test") { + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer( + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + Seq(Row(1.0, 260.182))) + sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + TestHive.reset() + } + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) From 883b7e9030e1a3948acee17608e51dcd9f4d55e1 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 25 Mar 2015 12:09:30 -0700 Subject: [PATCH 031/171] [SPARK-6076][Block Manager] Fix a potential OOM issue when StorageLevel is MEMORY_AND_DISK_SER In https://github.com/apache/spark/blob/dcd1e42d6b6ac08d2c0736bf61a15f515a1f222b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala#L538 , when StorageLevel is `MEMORY_AND_DISK_SER`, it will copy the content from file into 
memory, then put it into MemoryStore. ```scala val copyForMemory = ByteBuffer.allocate(bytes.limit) copyForMemory.put(bytes) memoryStore.putBytes(blockId, copyForMemory, level) bytes.rewind() ``` However, if the file is bigger than the free memory, OOM will happen. A better approach is testing if there is enough memory. If not, copyForMemory should not be created, since this is an optional operation. Author: zsxwing Closes #4827 from zsxwing/SPARK-6076 and squashes the following commits: 7d25545 [zsxwing] Add alias for tryToPut and dropFromMemory 1100a54 [zsxwing] Replace call-by-name with () => T 0cc0257 [zsxwing] Fix a potential OOM issue when StorageLevel is MEMORY_AND_DISK_SER --- .../apache/spark/storage/BlockManager.scala | 23 +++++++--- .../apache/spark/storage/MemoryStore.scala | 43 ++++++++++++++++--- .../spark/storage/BlockManagerSuite.scala | 34 +++++++++++++-- 3 files changed, 85 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 80d66e59132da..1dff09a75d038 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -535,9 +535,14 @@ private[spark] class BlockManager( /* We'll store the bytes in memory if the block's storage level includes * "memory serialized", or if it should be cached as objects in memory * but we only requested its serialized bytes. */ - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) + memoryStore.putBytes(blockId, bytes.limit, () => { + // https://issues.apache.org/jira/browse/SPARK-6076 + // If the file size is bigger than the free memory, OOM will happen. So if we cannot + // put it into MemoryStore, copyForMemory should not be created. That's why this + // action is put into a `() => ByteBuffer` and created lazily. + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + }) bytes.rewind() } if (!asBlockResult) { @@ -991,15 +996,23 @@ private[spark] class BlockManager( putIterator(blockId, Iterator(value), level, tellMaster) } + def dropFromMemory( + blockId: BlockId, + data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + dropFromMemory(blockId, () => data) + } + /** * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. * + * If `data` is not put on disk, it won't be created. + * * Return the block status if the given block has been updated, else None. 
*/ def dropFromMemory( blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull @@ -1023,7 +1036,7 @@ private[spark] class BlockManager( // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { logInfo(s"Writing block $blockId to disk") - data match { + data() match { case Left(elements) => diskStore.putArray(blockId, elements, level, returnValues = false) case Right(bytes) => diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 1be860aea63d0..ed609772e6979 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -98,6 +98,26 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and + * put it into MemoryStore. Otherwise, the ByteBuffer won't be created. + * + * The caller should guarantee that `size` is correct. + */ + def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { + // Work on a duplicate - since the original input might be used elsewhere. + lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] + val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val data = + if (putAttempt.success) { + assert(bytes.limit == size) + Right(bytes.duplicate()) + } else { + null + } + PutResult(size, data, putAttempt.droppedBlocks) + } + override def putArray( blockId: BlockId, values: Array[Any], @@ -312,11 +332,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId.asRDDId.map(_.rddId) } + private def tryToPut( + blockId: BlockId, + value: Any, + size: Long, + deserialized: Boolean): ResultWithDroppedBlocks = { + tryToPut(blockId, () => value, size, deserialized) + } + /** * Try to put in a set of values, if we can free up enough space. The value should either be * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size * must also be passed by the caller. * + * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be + * created to avoid OOM since it may be a big ByteBuffer. + * * Synchronize on `accountingLock` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. 
Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for @@ -326,7 +357,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ private def tryToPut( blockId: BlockId, - value: Any, + value: () => Any, size: Long, deserialized: Boolean): ResultWithDroppedBlocks = { @@ -345,7 +376,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlocks ++= freeSpaceResult.droppedBlocks if (enoughFreeSpace) { - val entry = new MemoryEntry(value, size, deserialized) + val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) currentMemory += size @@ -357,12 +388,12 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. - val data = if (deserialized) { - Left(value.asInstanceOf[Array[Any]]) + lazy val data = if (deserialized) { + Left(value().asInstanceOf[Array[Any]]) } else { - Right(value.asInstanceOf[ByteBuffer].duplicate()) + Right(value().asInstanceOf[ByteBuffer].duplicate()) } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) + val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 3fdbe99b5d02b..ecd1cba5b5abe 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -170,8 +170,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") @@ -413,8 +413,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach t2.join() t3.join() - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) store.waitForAsyncReregister() } } @@ -1223,4 +1223,30 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) } + + test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + val result = memoryStore.putBytes(blockId, 13000, () => { + fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") + }) + assert(result.size === 13000) + assert(result.data === null) + assert(result.droppedBlocks === 
Nil) + } + + test("put a small ByteBuffer to MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + var bytes: ByteBuffer = null + val result = memoryStore.putBytes(blockId, 10000, () => { + bytes = ByteBuffer.allocate(10000) + bytes + }) + assert(result.size === 10000) + assert(result.data === Right(bytes)) + assert(result.droppedBlocks === Nil) + } } From acef51defb991bcdc99b76cf2a126afd6d60ec70 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 25 Mar 2015 13:27:15 -0700 Subject: [PATCH 032/171] [SPARK-6537] UIWorkloadGenerator: The main thread should not stop SparkContext until all jobs finish The main thread of UIWorkloadGenerator spawn sub threads to launch jobs but the main thread stop SparkContext without waiting for finishing those threads. Author: Kousuke Saruta Closes #5187 from sarutak/SPARK-6537 and squashes the following commits: 4e9307a [Kousuke Saruta] Fixed UIWorkloadGenerator so that the main thread stop SparkContext after all jobs finish --- .../scala/org/apache/spark/ui/UIWorkloadGenerator.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 19ac7a826e306..5fbcd6bb8ad94 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui +import java.util.concurrent.Semaphore + import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} @@ -88,6 +90,8 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) + val barrier = new Semaphore(-nJobSet * jobs.size + 1) + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { @@ -99,12 +103,17 @@ private[spark] object UIWorkloadGenerator { } catch { case e: Exception => println("Job Failed: " + desc) + } finally { + barrier.release() } } }.start Thread.sleep(INTER_JOB_WAIT_MS) } } + + // Waiting for threads. + barrier.acquire() sc.stop() } } From c1b74df6042b33b2b061cb07c2fbd82dba9074bb Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 25 Mar 2015 13:28:32 -0700 Subject: [PATCH 033/171] [SPARK-5771] Master UI inconsistently displays application cores If the user calls `sc.stop()`, then the number of cores under "Completed Applications" will be 0. If the user does not call `sc.stop()`, then the number of cores will be however many cores were being used before the application exited. This PR makes both cases have the behavior of the latter. Note that there have been a series of PR that attempted to fix this. For the full discussion, please refer to #4841. The unregister event is necessary because of a subtle race condition explained in that PR. Tested this locally with and without calling `sc.stop()`. 
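The essence of the fix (shown in full in the diff below) is to stop mutating an application's executor bookkeeping once it is no longer WAITING or RUNNING, so the completed-application entry keeps the core count it had while running. A simplified, hypothetical sketch of that guard:

```scala
// Sketch only; the real change lives in ApplicationInfo and Master (see the diff below).
object AppState extends Enumeration { val WAITING, RUNNING, FINISHED, FAILED = Value }

class AppInfo(var state: AppState.Value, var coresGranted: Int) {
  def isFinished: Boolean = state != AppState.WAITING && state != AppState.RUNNING

  def removeExecutor(cores: Int): Unit = {
    // If the application has already finished, preserve its recorded state so the
    // UI can still display the cores it was using.
    if (!isFinished) {
      coresGranted -= cores
    }
  }
}
```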
Author: Andrew Or Closes #5177 from andrewor14/master-ui-cores and squashes the following commits: 62449d1 [Andrew Or] Freeze application state before finishing it --- .../scala/org/apache/spark/deploy/DeployMessage.scala | 2 ++ .../org/apache/spark/deploy/client/AppClient.scala | 1 + .../apache/spark/deploy/master/ApplicationInfo.scala | 4 ++++ .../scala/org/apache/spark/deploy/master/Master.scala | 10 +++++++++- 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 0997507d016f5..9db6fd1ac4dbe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -101,6 +101,8 @@ private[deploy] object DeployMessages { case class RegisterApplication(appDescription: ApplicationDescription) extends DeployMessage + case class UnregisterApplication(appId: String) + case class MasterChangeAcknowledged(appId: String) // Master to AppClient diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 3b729725257ef..4f06d7f96c46e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -157,6 +157,7 @@ private[spark] class AppClient( case StopAppClient => markDead("Application has been stopped.") + master ! UnregisterApplication(appId) sender ! true context.stop(self) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index f979ffa16641a..bc5b293379f2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -111,6 +111,10 @@ private[deploy] class ApplicationInfo( endTime = System.currentTimeMillis() } + private[master] def isFinished: Boolean = { + state != ApplicationState.WAITING && state != ApplicationState.RUNNING + } + def duration: Long = { if (endTime != -1) { endTime - startTime diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 80506621f4d24..9a5d5877da86d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -339,7 +339,11 @@ private[master] class Master( if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") - appInfo.removeExecutor(exec) + // If an application has already finished, preserve its + // state to display its information properly on the UI + if (!appInfo.isFinished) { + appInfo.removeExecutor(exec) + } exec.worker.removeExecutor(exec) val normalExit = exitStatus == Some(0) @@ -428,6 +432,10 @@ private[master] class Master( if (canCompleteRecovery) { completeRecovery() } } + case UnregisterApplication(applicationId) => + logInfo(s"Received unregister request from application $applicationId") + idToApp.get(applicationId).foreach(finishApplication) + case DisassociatedEvent(_, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") From 
435337381f093f95248c8f0204e60c0b366edc81 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 25 Mar 2015 13:38:33 -0700 Subject: [PATCH 034/171] [SPARK-6256] [MLlib] MLlib Python API parity check for regression MLlib Python API parity check for Regression, major disparities need to be added for Python list following: ```scala LinearRegressionWithSGD setValidateData LassoWithSGD setIntercept setValidateData RidgeRegressionWithSGD setIntercept setValidateData ``` setFeatureScaling is mllib private function which is not needed to expose in pyspark. Author: Yanbo Liang Closes #4997 from yanboliang/spark-6256 and squashes the following commits: 102f498 [Yanbo Liang] fix intercept issue & add doc test 1fb7b4f [Yanbo Liang] change 'intercept' to 'addIntercept' de5ecbc [Yanbo Liang] MLlib Python API parity check for regression --- .../mllib/api/python/PythonMLLibAPI.scala | 16 +++++-- python/pyspark/mllib/regression.py | 43 ++++++++++++++++--- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 15ca2547d56a8..e39156734794c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -111,9 +111,11 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) + .setValidateData(validateData) lrAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -135,8 +137,12 @@ private[python] class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lassoAlg = new LassoWithSGD() + lassoAlg.setIntercept(intercept) + .setValidateData(validateData) lassoAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -157,8 +163,12 @@ private[python] class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.setIntercept(intercept) + .setValidateData(validateData) ridgeAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 414a0ada80787..209f1ee473b5b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -140,6 +140,13 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2", + ... 
intercept=True, validateData=True) + >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( @@ -173,7 +180,8 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.0, regType=None, intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False, + validateData=True): """ Train a linear regression model on the given data. @@ -195,15 +203,18 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: None) - @param intercept: Boolean parameter which indicates the use + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). (default: False) + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept)) + regType, bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) @@ -253,6 +264,13 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( @@ -273,11 +291,13 @@ class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a Lasso regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -327,6 +347,13 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... 
validateData=True) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( @@ -347,11 +374,13 @@ class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a ridge regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) From 4fc4d0369e8240defe0ee83252426402f1a28a36 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 25 Mar 2015 14:45:23 -0700 Subject: [PATCH 035/171] [SPARK-5987] [MLlib] Save/load for GaussianMixtureModels Should be self explanatory. Author: MechCoder Closes #4986 from MechCoder/spark-5987 and squashes the following commits: 7d2cd56 [MechCoder] Iterate over dataframe in a better way e7a14cb [MechCoder] Minor 33c84f9 [MechCoder] Store as Array[Data] instead of Data[Array] 505bd57 [MechCoder] Rebased over master and used MatrixUDT 7422bb4 [MechCoder] Store sigmas as Array[Double] instead of Array[Array[Double]] b9794e4 [MechCoder] Minor cb77095 [MechCoder] [SPARK-5987] Save/load for GaussianMixtureModels --- docs/mllib-clustering.md | 8 ++ .../clustering/GaussianMixtureModel.scala | 96 ++++++++++++++++++- .../clustering/GaussianMixtureSuite.scala | 52 +++++++--- 3 files changed, 136 insertions(+), 20 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 0b6db4fcb7b1f..f5aa15b7d9b79 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model. 
{% highlight scala %} import org.apache.spark.mllib.clustering.GaussianMixture +import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) // Cluster the data into two classes using GaussianMixture val gmm = new GaussianMixture().setK(2).run(parsedData) +// Save and load model +gmm.save(sc, "myGMMModel") +val sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + // output parameters of max-likelihood model for (i <- 0 until gmm.k) { println("weight=%f\nmu=%s\nsigma=\n%s\n" format @@ -231,6 +236,9 @@ public class GaussianMixtureExample { // Cluster the data into two classes using GaussianMixture GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); + // Save and load GaussianMixtureModel + gmm.save(sc, "myGMMModel") + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") // Output the parameters of the mixture model for(int j=0; j BreezeVector} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: @@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD @Experimental class GaussianMixtureModel( val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable { + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") - + + override protected def formatVersion = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) + } + /** Number of gaussians in mixture */ def k: Int = weights.length @@ -83,5 +95,79 @@ class GaussianMixtureModel( p(i) /= pSum } p - } + } +} + +@Experimental +object GaussianMixtureModel extends Loader[GaussianMixtureModel] { + + private object SaveLoadV1_0 { + + case class Data(weight: Double, mu: Vector, sigma: Matrix) + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel" + + def save( + sc: SparkContext, + path: String, + weights: Array[Double], + gaussians: Array[MultivariateGaussian]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. 
+ val dataArray = Array.tabulate(weights.length) { i => + Data(weights(i), gaussians(i).mu, gaussians(i).sigma) + } + sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): GaussianMixtureModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val (weights, gaussians) = dataArray.map { + case Row(weight: Double, mu: Vector, sigma: Matrix) => + (weight, new MultivariateGaussian(mu, sigma)) + }.unzip + + return new GaussianMixtureModel(weights.toArray, gaussians.toArray) + } + } + + override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val k = (metadata \ "k").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, version) match { + case (classNameV1_0, "1.0") => { + val model = SaveLoadV1_0.load(sc, path) + require(model.weights.length == k, + s"GaussianMixtureModel requires weights of length $k " + + s"got weights of length ${model.weights.length}") + require(model.gaussians.length == k, + s"GaussianMixtureModel requires gaussians of length $k" + + s"got gaussians of length ${model.gaussians.length}") + model + } + case _ => throw new Exception( + s"GaussianMixtureModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index 1b46a4012d731..f356ffa3e3a26 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { test("single cluster") { @@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } test("two clusters") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) + val data = sc.parallelize(GaussianTestData.data) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } test("two clusters with sparse data") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - 
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) - + val data = sc.parallelize(GaussianTestData.data) val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + + test("model save / load") { + val data = sc.parallelize(GaussianTestData.data) + + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + gmm.save(sc, path) + + // TODO: GaussianMixtureModel should implement equals/hashcode directly. + val sameModel = GaussianMixtureModel.load(sc, path) + assert(sameModel.k === gmm.k) + (0 until sameModel.k).foreach { i => + assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu) + assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + + object GaussianTestData { + + val data = Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + ) + + } } From d44a3362ed8cf3068f8ff233e13851a39da42219 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 25 Mar 2015 17:40:00 -0700 Subject: [PATCH 036/171] [SPARK-6079] Use index to speed up StatusTracker.getJobIdsForGroup() `StatusTracker.getJobIdsForGroup()` is implemented via a linear scan over a HashMap rather than using an index, which might be an expensive operation if there are many (e.g. thousands) of retained jobs. This patch adds a new map to `JobProgressListener` in order to speed up these lookups. 
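The index described above amounts to keeping a second map from job group to job IDs alongside the primary `jobIdToData` map, so a group lookup becomes a single map access instead of a scan over every retained job. A minimal sketch with simplified, hypothetical types (the real `JobProgressListener` change follows in the diff):

```scala
import scala.collection.mutable

case class JobData(jobId: Int, jobGroup: Option[String])

val jobIdToData = mutable.HashMap.empty[Int, JobData]
val jobGroupToJobIds = mutable.HashMap.empty[String, mutable.HashSet[Int]] // the new index

def onJobStart(job: JobData): Unit = {
  jobIdToData(job.jobId) = job
  // Jobs without a group are keyed under null, mirroring the patch's use of orNull.
  jobGroupToJobIds.getOrElseUpdate(job.jobGroup.orNull, mutable.HashSet.empty) += job.jobId
}

// One map lookup instead of filtering all retained jobs by group.
def getJobIdsForGroup(group: String): Array[Int] =
  jobGroupToJobIds.getOrElse(group, mutable.HashSet.empty[Int]).toArray
```

As the diff shows, the same index also has to be pruned when old jobs are evicted, otherwise the map of sets would leak empty entries.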
Author: Josh Rosen Closes #4830 from JoshRosen/statustracker-job-group-indexing and squashes the following commits: e39c5c7 [Josh Rosen] Address review feedback 6709fb2 [Josh Rosen] Merge remote-tracking branch 'origin/master' into statustracker-job-group-indexing 2c49614 [Josh Rosen] getOrElse 97275a7 [Josh Rosen] Add jobGroup to jobId index to JobProgressListener --- .../org/apache/spark/SparkStatusTracker.scala | 3 +- .../spark/ui/jobs/JobProgressListener.scala | 23 ++++++++++++-- .../ui/jobs/JobProgressListenerSuite.scala | 31 +++++++++++++++++-- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index edbdda8a0bcb6..34ee3a48f8e74 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -45,8 +45,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 949e80d30f5eb..625596885faa1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -44,6 +44,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // These type aliases are public because they're used in the types of public fields: type JobId = Int + type JobGroupId = String type StageId = Int type StageAttemptId = Int type PoolName = String @@ -54,6 +55,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val completedJobs = ListBuffer[JobUIData]() val failedJobs = ListBuffer[JobUIData]() val jobIdToData = new HashMap[JobId, JobUIData] + val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] // Stages: val pendingStages = new HashMap[StageId, StageInfo] @@ -119,7 +121,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { Map( "jobIdToData" -> jobIdToData.size, "stageIdToData" -> stageIdToData.size, - "stageIdToStageInfo" -> stageIdToInfo.size + "stageIdToStageInfo" -> stageIdToInfo.size, + "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum, + // Since jobGroupToJobIds is map of sets, check that we don't leak keys with empty values: + "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size ) } @@ -140,7 +145,19 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (jobs.size > retainedJobs) { val toRemove = math.max(retainedJobs / 10, 1) jobs.take(toRemove).foreach { job => - jobIdToData.remove(job.jobId) + // Remove the job's UI data, if it exists + jobIdToData.remove(job.jobId).foreach { removedJob => + // A null jobGroupId is used for jobs that are run without a job group + val jobGroupId = removedJob.jobGroup.orNull + // Remove the job group -> job mapping entry, if it exists + jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => + jobsInGroup.remove(job.jobId) + // If this was the last job in this job group, remove the map entry for the job group + if (jobsInGroup.isEmpty) { + 
jobGroupToJobIds.remove(jobGroupId) + } + } + } } jobs.trimStart(toRemove) } @@ -158,6 +175,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) + // A null jobGroupId is used for jobs that are run without a job group + jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) // Compute (a potential underestimate of) the number of tasks that will be run by this job. // This may be an underestimate because the job start event references all of the result diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 730a4b54f5aa1..c0c28cb60e21d 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import java.util.Properties + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -44,11 +46,19 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc SparkListenerStageCompleted(stageInfo) } - private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { + private def createJobStartEvent( + jobId: Int, + stageIds: Seq[Int], + jobGroup: Option[String] = None): SparkListenerJobStart = { val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) + val properties: Option[Properties] = jobGroup.map { groupId => + val props = new Properties() + props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + props + } + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { @@ -110,6 +120,23 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc listener.stageIdToActiveJobIds.size should be (0) } + test("test clearing of jobGroupToJobIds") { + val conf = new SparkConf() + conf.set("spark.ui.retainedJobs", 5.toString) + val listener = new JobProgressListener(conf) + + // Run 50 jobs, each with one stage + for (jobId <- 0 to 50) { + listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString))) + listener.onStageSubmitted(createStageStartEvent(0)) + listener.onStageCompleted(createStageEndEvent(0, failed = false)) + listener.onJobEnd(createJobEndEvent(jobId, false)) + } + assertActiveJobsStateIsEmpty(listener) + // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs + listener.jobGroupToJobIds.size should be (5) + } + test("test LRU eviction of jobs") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) From 8c3b0052f4792d97d23244ade335676e37cb1fae Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 25 Mar 2015 17:40:19 -0700 Subject: [PATCH 037/171] [SPARK-6450] [SQL] Fixes metastore Parquet table conversion The `ParquetConversions` analysis rule generates a hash map, which maps from the original `MetastoreRelation` instances to the newly created `ParquetRelation2` instances. However, `MetastoreRelation.equals` doesn't compare output attributes. 
Thus, if a single metastore Parquet table appears multiple times in a query, only a single entry ends up in the hash map, and the conversion is not correctly performed. Proper fix for this issue should be overriding `equals` and `hashCode` for MetastoreRelation. Unfortunately, this breaks more tests than expected. It's possible that these tests are ill-formed from the very beginning. As 1.3.1 release is approaching, we'd like to make the change more surgical to avoid potential regressions. The proposed fix here is to make both the metastore relations and their output attributes as keys in the hash map used in ParquetConversions. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5183) Author: Cheng Lian Closes #5183 from liancheng/spark-6450 and squashes the following commits: 3536780 [Cheng Lian] Fixes metastore Parquet table conversion --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 34 ++++++++++--------- .../apache/spark/sql/hive/parquetSuites.scala | 25 ++++++++++++++ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4c5eb48661f7d..d1a99555e90c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -459,7 +459,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) // Write path case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) @@ -470,7 +470,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) @@ -479,33 +479,35 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) } + // Quick fix for SPARK-6450: Notice that we're using both the MetastoreRelation instances and + // their output attributes as the key of the map. This is because MetastoreRelation.equals + // doesn't take output attributes into account, thus multiple MetastoreRelation instances + // pointing to the same table get collapsed into a single entry in the map. A proper fix for + // this should be overriding equals & hashCode in MetastoreRelation. 
val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes // attribute IDs referenced in other nodes. plan.transformUp { - case r: MetastoreRelation if relationMap.contains(r) => { - val parquetRelation = relationMap(r) - val withAlias = - r.alias.map(a => Subquery(a, parquetRelation)).getOrElse( - Subquery(r.tableName, parquetRelation)) + case r: MetastoreRelation if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) + val alias = r.alias.getOrElse(r.tableName) + Subquery(alias, parquetRelation) - withAlias - } case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r) => { - val parquetRelation = relationMap(r) + if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) InsertIntoTable(parquetRelation, partition, child, overwrite) - } + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r) => { - val parquetRelation = relationMap(r) + if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) InsertIntoTable(parquetRelation, partition, child, overwrite) - } + case other => other.transformExpressions { case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 8a31bd03092d1..432d65a874518 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -365,6 +365,31 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE IF EXISTS test_insert_parquet") } + + test("SPARK-6450 regression test") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed + + assertResult(2) { + analyzed.collect { + case r @ LogicalRelation(_: ParquetRelation2) => r + }.size + } + + sql("DROP TABLE ms_convert") + } } class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { From e6d1406abd55bc24477eb8c6ee72c31e7110435e Mon Sep 17 00:00:00 2001 From: jeanlyn Date: Wed, 25 Mar 2015 17:47:45 -0700 Subject: [PATCH 038/171] [SPARK-5498][SQL]fix query exception when partition schema does not match table schema In hive,the schema of partition may be difference from the table schema.When we use spark-sql to query the data of partition which schema is difference from the table schema,we will get the exceptions as the description of the [jira](https://issues.apache.org/jira/browse/SPARK-5498) .For example: * We take a look of the schema for the partition and the table ```sql DESCRIBE partition_test PARTITION (dt='1'); id int None name string None dt string None # Partition Information # col_name data_type comment dt string None ``` ``` DESCRIBE partition_test; OK id bigint 
None name string None dt string None # Partition Information # col_name data_type comment dt string None ``` * run the sql ```sql SELECT * FROM partition_test where dt='1'; ``` we will get the cast exception `java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.MutableLong cannot be cast to org.apache.spark.sql.catalyst.expressions.MutableInt` Author: jeanlyn Closes #4289 from jeanlyn/schema and squashes the following commits: 9c8da74 [jeanlyn] fix style b41d6b9 [jeanlyn] fix compile errors 07d84b6 [jeanlyn] Merge branch 'master' into schema 535b0b6 [jeanlyn] reduce conflicts d6c93c5 [jeanlyn] fix bug 1e8b30c [jeanlyn] fix code style 0549759 [jeanlyn] fix code style c879aa1 [jeanlyn] clean the code 2a91a87 [jeanlyn] add more test case and clean the code 12d800d [jeanlyn] fix code style 63d170a [jeanlyn] fix compile problem 7470901 [jeanlyn] reduce conflicts afc7da5 [jeanlyn] make getConvertedOI compatible between 0.12.0 and 0.13.1 b1527d5 [jeanlyn] fix type mismatch 10744ca [jeanlyn] Insert a space after the start of the comment 3b27af3 [jeanlyn] SPARK-5498:fix bug when query the data when partition schema does not match table schema --- .../apache/spark/sql/hive/TableReader.scala | 37 +++++++++++----- .../sql/hive/InsertIntoHiveTableSuite.scala | 42 +++++++++++++++++++ .../org/apache/spark/sql/hive/Shim12.scala | 10 ++++- .../org/apache/spark/sql/hive/Shim13.scala | 9 +++- 4 files changed, 84 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index af309c0c6ce2c..3563472c7ae81 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -116,7 +116,7 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } deserializedHadoopRDD @@ -189,9 +189,13 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) + // get the table deserializer + val tableSerDe = tableDesc.getDeserializerClass.newInstance() + tableSerDe.initialize(hconf, tableDesc.getProperties) // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, + mutableRow, tableSerDe) } }.toSeq @@ -261,25 +265,36 @@ private[hive] object HadoopTableReader extends HiveInspectors { * Transform all given raw `Writable`s into `Row`s. 
* * @param iterator Iterator of all `Writable`s to be transformed - * @param deserializer The `Deserializer` associated with the input `Writable` + * @param rawDeser The `Deserializer` associated with the input `Writable` * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding * positions in the output schema * @param mutableRow A reusable `MutableRow` that should be filled + * @param tableDeser Table Deserializer * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( iterator: Iterator[Writable], - deserializer: Deserializer, + rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow, + tableDeser: Deserializer): Iterator[Row] = { + + val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { + rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] + } else { + HiveShim.getConvertedOI( + rawDeser.getObjectInspector, + tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] + } - val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal }.unzip - // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern - // matching and branching costs per row. + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { _.getFieldObjectInspector match { case oi: BooleanObjectInspector => @@ -316,9 +331,11 @@ private[hive] object HadoopTableReader extends HiveInspectors { } } + val converter = ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi) + // Map each tuple to a row object iterator.map { value => - val raw = deserializer.deserialize(value) + val raw = converter.convert(rawDeser.deserialize(value)) var i = 0 while (i < fieldRefs.length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 381cd2a29123e..aa6fb42de7f88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -32,9 +32,12 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) +case class ThreeCloumntable(key: Int, value: String, key1: String) + class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ + val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() @@ -186,4 +189,43 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { sql("DROP TABLE hiveTableWithStructValue") } + + test("SPARK-5498:partition schema does not match table schema") { + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val testDatawithNull = TestHive.sparkContext.parallelize( + (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() + + val tmpDir = Files.createTempDir() + sql(s"CREATE TABLE table_with_partition(key 
int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") + + // test schema the same between partition and table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + // test difference type of field + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + // add column to table + sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") + checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), + testDatawithNull.collect.toSeq + ) + + // change column name to table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") + checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + sql("DROP TABLE table_with_partition") + } } diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 30646ddbc29d8..0ed93c2c5b1fa 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} @@ -210,7 +210,7 @@ private[hive] object HiveShim { def getDataLocationPath(p: Partition) = p.getPartitionPath - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) + def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) def compatibilityBlackList = Seq( "decimal_.*", @@ -244,6 +244,12 @@ private[hive] object HiveShim { } } + def getConvertedOI( + inputOI: ObjectInspector, + outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) + } + def prepareWritable(w: Writable): Writable = { w } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index f9fcbdae15745..7577309900209 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import java.util import java.util.{ArrayList => JArrayList} import java.util.Properties import java.rmi.server.UID @@ -38,7 +39,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import 
org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable @@ -400,7 +401,11 @@ private[hive] object HiveShim { Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) } } - + + def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) + } + /* * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that * is needed to initialize before serialization. From 73d57754dd23d84331c10355338a4240b3ac5fee Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 25 Mar 2015 17:52:23 -0700 Subject: [PATCH 039/171] [SPARK-6326][SQL] Improve castStruct to be faster Current `castStruct` should be very slow. This pr slightly improves it. Author: Liang-Chi Hsieh Closes #5017 from viirya/faster_caststruct and squashes the following commits: 385d5b0 [Liang-Chi Hsieh] Further improved. 746fcfb [Liang-Chi Hsieh] Make castStruct faster. --- .../spark/sql/catalyst/expressions/Cast.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 9bde74ac22669..31f1a5fdc7e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -394,10 +394,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val casts = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } - // TODO: This is very slow! - buildCast[Row](_, row => Row(row.toSeq.zip(casts).map { - case (v, cast) => if (v == null) null else cast(v) - }: _*)) + // TODO: Could be faster? 
+ val newRow = new GenericMutableRow(from.fields.size) + buildCast[Row](_, row => { + var i = 0 + while (i < row.length) { + val v = row(i) + newRow.update(i, if (v == null) null else casts(i)(v)) + i += 1 + } + newRow.copy() + }) } private[this] def cast(from: DataType, to: DataType): Any => Any = to match { From 328daf65f8e320d3bdfd12397a8f85891f1d14c7 Mon Sep 17 00:00:00 2001 From: DoingDone9 <799203320@qq.com> Date: Wed, 25 Mar 2015 18:41:59 -0700 Subject: [PATCH 040/171] [SPARK-6271][SQL] Sort these tokens in alphabetic order to avoid further duplicate in HiveQl Author: DoingDone9 <799203320@qq.com> Closes #4973 from DoingDone9/sort_token and squashes the following commits: 855fa10 [DoingDone9] Update HiveQl.scala c7080b3 [DoingDone9] Sort these tokens in alphabetic order to avoid further duplicate in HiveQl c87e8b6 [DoingDone9] Merge pull request #3 from apache/master cb1852d [DoingDone9] Merge pull request #2 from apache/master c3f046f [DoingDone9] Merge pull request #1 from apache/master --- .../org/apache/spark/sql/hive/HiveQl.scala | 88 ++++++++++--------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 51775eb4cd6a0..c45c4ad70fae9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -55,37 +55,8 @@ private[hive] case object NativePlaceholder extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( - "TOK_DESCFUNCTION", - "TOK_DESCDATABASE", - "TOK_SHOW_CREATETABLE", - "TOK_SHOWCOLUMNS", - "TOK_SHOW_TABLESTATUS", - "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", - "TOK_SHOWINDEXES", - "TOK_SHOWINDEXES", - "TOK_SHOWPARTITIONS", - "TOK_SHOW_TBLPROPERTIES", - - "TOK_LOCKTABLE", - "TOK_SHOWLOCKS", - "TOK_UNLOCKTABLE", - - "TOK_SHOW_ROLES", - "TOK_CREATEROLE", - "TOK_DROPROLE", - "TOK_GRANT", - "TOK_GRANT_ROLE", - "TOK_REVOKE", - "TOK_SHOW_GRANT", - "TOK_SHOW_ROLE_GRANT", - "TOK_SHOW_SET_ROLE", - - "TOK_CREATEFUNCTION", - "TOK_DROPFUNCTION", - - "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERDATABASE_OWNER", + "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", "TOK_ALTERTABLE_ADDCOLS", @@ -102,28 +73,61 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_CREATEDATABASE", - "TOK_CREATEFUNCTION", - "TOK_CREATEINDEX", - "TOK_DROPDATABASE", - "TOK_DROPINDEX", - "TOK_DROPTABLE_PROPERTIES", - "TOK_MSCK", - "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", "TOK_ALTERVIEW_PROPERTIES", "TOK_ALTERVIEW_RENAME", + + "TOK_CREATEDATABASE", + "TOK_CREATEFUNCTION", + "TOK_CREATEINDEX", + "TOK_CREATEROLE", "TOK_CREATEVIEW", - "TOK_DROPVIEW_PROPERTIES", + + "TOK_DESCDATABASE", + "TOK_DESCFUNCTION", + + "TOK_DROPDATABASE", + "TOK_DROPFUNCTION", + "TOK_DROPINDEX", + "TOK_DROPROLE", + "TOK_DROPTABLE_PROPERTIES", "TOK_DROPVIEW", - + "TOK_DROPVIEW_PROPERTIES", + "TOK_EXPORT", + + "TOK_GRANT", + "TOK_GRANT_ROLE", + "TOK_IMPORT", + "TOK_LOAD", - - "TOK_SWITCHDATABASE" + + "TOK_LOCKTABLE", + + "TOK_MSCK", + + "TOK_REVOKE", + + "TOK_SHOW_CREATETABLE", + "TOK_SHOW_GRANT", + "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLES", + "TOK_SHOW_SET_ROLE", + "TOK_SHOW_TABLESTATUS", + "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOWCOLUMNS", + "TOK_SHOWDATABASES", + 
"TOK_SHOWFUNCTIONS", + "TOK_SHOWINDEXES", + "TOK_SHOWLOCKS", + "TOK_SHOWPARTITIONS", + + "TOK_SWITCHDATABASE", + + "TOK_UNLOCKTABLE" ) // Commands that we do not need to explain. From 5ab6e9f0c06b763b06a0e1396bdb08a823146c32 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 25 Mar 2015 18:43:26 -0700 Subject: [PATCH 041/171] [SPARK-6202] [SQL] enable variable substitution on test framework Author: Daoyuan Wang Closes #4930 from adrian-wang/testvs and squashes the following commits: 2ce590f [Daoyuan Wang] add explicit function types b1d68bf [Daoyuan Wang] only substitute for parseSql 9c4a950 [Daoyuan Wang] add a comment explaining 18fb481 [Daoyuan Wang] enable variable substitute on test framework --- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index dc61e9d2e3522..a3497eadd67f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -23,6 +23,7 @@ import java.util.{Set => JavaSet} import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.RegexSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -153,8 +154,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r + val vs = new VariableSubstitution() + + // we should substitute variables in hql to pass the text to parseSql() as a parameter. + // Hive parser need substituted text. HiveContext.sql() does this but return a DataFrame, + // while we need a logicalPlan so we cannot reuse that. protected[hive] class HiveQLQueryExecution(hql: String) - extends this.QueryExecution(HiveQl.parseSql(hql)) { + extends this.QueryExecution(HiveQl.parseSql(vs.substitute(hiveconf, hql))) { def hiveExec(): Seq[String] = runSqlHive(hql) override def toString: String = hql + "\n" + super.toString } From e87bf3713e684fa83165a1036d76f7a84f043775 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei Date: Wed, 25 Mar 2015 19:15:30 -0700 Subject: [PATCH 042/171] =?UTF-8?q?The=20UT=20test=20of=20spark=20is=20fai?= =?UTF-8?q?led.=20Because=20there=20is=20a=20test=20in=20SQLQuerySuite=20a?= =?UTF-8?q?bout=20creating=20table=20=E2=80=9Ctest=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If the tests in "sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala" are running before CachedTableSuite.scala, the test("Drop cached table") will failed. Because the table test is created in SQLQuerySuite.scala ,and this table not droped. So when running "drop cached table", table test already exists. 
There is error info: 01:18:35.738 ERROR hive.ql.exec.DDLTask: org.apache.hadoop.hive.ql.metadata.HiveException: AlreadyExistsException(message:Table test already exists) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:616) at org.apache.hadoop.hive.ql.exec.DDLTask.createTable(DDLTask.java:4189) at org.apache.hadoop.hive.ql.exec.DDLTask.execute(DDLTask.java:281) at org.apache.hadoop.hive.ql.exec.Task.executeTask(Task.java:153) at org.apache.hadoop.hive.ql.exec.TaskRunner.runSequential(TaskRunner.java:85) at org.apache.hadoop.hive.ql.Driver.launchTask(Driver.java:1503) at org.apache.hadoop.hive.ql.Driver.execute(Driver.java:1270) at org.apache.hadoop.hive.ql.Driver.runInternal(Driver.java:1088) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:911) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:901)test” And the test about "create table test" in "sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala,is: test("SPARK-4825 save join to table") { val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") testData.insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") testData.insertInto("test2") testData.insertInto("test2") sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( table("test"), sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) } Author: KaiXinXiaoLei Closes #5150 from KaiXinXiaoLei/testFailed and squashes the following commits: 7534b02 [KaiXinXiaoLei] The UT test of spark is failed. --- .../org/apache/spark/sql/hive/CachedTableSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 44d24273e722a..221a0c263d36c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -92,12 +92,12 @@ class CachedTableSuite extends QueryTest { } test("Drop cached table") { - sql("CREATE TABLE test(a INT)") - cacheTable("test") - sql("SELECT * FROM test").collect() - sql("DROP TABLE test") + sql("CREATE TABLE cachedTableTest(a INT)") + cacheTable("cachedTableTest") + sql("SELECT * FROM cachedTableTest").collect() + sql("DROP TABLE cachedTableTest") intercept[AnalysisException] { - sql("SELECT * FROM test").collect() + sql("SELECT * FROM cachedTableTest").collect() } } From 276ef1c3cfd44b5fc082e1a495fff22fbaf6add3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 25 Mar 2015 19:21:54 -0700 Subject: [PATCH 043/171] [SPARK-6463][SQL] AttributeSet.equal should compare size Previously this could result in sets compare equals when in fact the right was a subset of the left. 
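To make the failure mode concrete, here is a minimal standalone sketch (plain Scala sets, not Spark's `AttributeSet`) of why an equality check that only verifies one-sided containment, without comparing sizes, reports a proper subset as equal:

```scala
object SubsetEqualsSketch {
  final case class Attr(id: Int)

  // Flawed check: only asks whether every element of `a` appears in `b`.
  def equalsNoSizeCheck(a: Set[Attr], b: Set[Attr]): Boolean =
    a.forall(b.contains)

  // Fixed check: sizes must also match, so a proper subset is not "equal".
  def equalsWithSizeCheck(a: Set[Attr], b: Set[Attr]): Boolean =
    a.size == b.size && a.forall(b.contains)

  def main(args: Array[String]): Unit = {
    val subset = Set(Attr(1))
    val superset = Set(Attr(1), Attr(2))
    assert(equalsNoSizeCheck(subset, superset))    // wrongly reported as equal
    assert(!equalsWithSizeCheck(subset, superset)) // correctly unequal
  }
}
```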
Based on #5133 by sisihj Author: sisihj Author: Michael Armbrust Closes #5194 from marmbrus/pr/5133 and squashes the following commits: 5ed4615 [Michael Armbrust] fix imports d4cbbc0 [Michael Armbrust] Add test cases 0a0834f [sisihj] AttributeSet.equal should compare size --- .../catalyst/expressions/AttributeSet.scala | 3 +- .../expressions/AttributeSetSuite.scala | 82 +++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index f9ae85a5cfc1b..11b4eb5c888be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -58,7 +58,8 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) /** Returns true if the members of this AttributeSet and other are the same. */ override def equals(other: Any): Boolean = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 0000000000000..f2f3a84d19380 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} From f535802977c5a3ce45894d89fdf59f8723f023c8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 26 Mar 2015 00:01:24 -0700 Subject: [PATCH 044/171] [SPARK-6536] [PySpark] Column.inSet() in Python ``` >>> df[df.name.inSet("Bob", "Mike")].collect() [Row(age=5, name=u'Bob')] >>> df[df.age.inSet([1, 2, 3])].collect() [Row(age=2, name=u'Alice')] ``` Author: Davies Liu Closes #5190 from davies/in and squashes the following commits: 6b73a47 [Davies Liu] Column.inSet() in Python --- python/pyspark/sql/dataframe.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cb89da7a8ed5..bf7c47b7261a9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -985,6 +985,23 @@ def substr(self, startPos, length): __getslice__ = substr + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. 
+ + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + # order asc = _unary_op("asc", "Returns a sort expression based on the" " ascending order of the given column name.") From 5bbcd1304cfebba31ec6857a80d3825a40d02e83 Mon Sep 17 00:00:00 2001 From: azagrebin Date: Thu, 26 Mar 2015 00:25:04 -0700 Subject: [PATCH 045/171] [SPARK-6117] [SQL] add describe function to DataFrame for summary statis... Please review my solution for SPARK-6117 Author: azagrebin Closes #5073 from azagrebin/SPARK-6117 and squashes the following commits: f9056ac [azagrebin] [SPARK-6117] [SQL] create one aggregation and split it locally into resulting DF, colocate test data with test case ddb3950 [azagrebin] [SPARK-6117] [SQL] simplify implementation, add test for DF without numeric columns 9daf31e [azagrebin] [SPARK-6117] [SQL] add describe function to DataFrame for summary statistics --- .../org/apache/spark/sql/DataFrame.scala | 53 ++++++++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 45 ++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5aece166aad22..db561825e676b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType} import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils @@ -751,6 +751,57 @@ class DataFrame private[sql]( select(colNames :_*) } + /** + * Compute numerical statistics for given columns of this [[DataFrame]]: + * count, mean (avg), stddev (standard deviation), min, max. + * Each row of the resulting [[DataFrame]] contains column with statistic name + * and columns with statistic results for each given column. + * If no columns are given then computes for all numerical columns. 
+ * + * {{{ + * df.describe("age", "height") + * + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // max 92.0 192.0 + * }}} + */ + @scala.annotation.varargs + def describe(cols: String*): DataFrame = { + + def stddevExpr(expr: Expression) = + Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + + val statistics = List[(String, Expression => Expression)]( + "count" -> Count, + "mean" -> Average, + "stddev" -> stddevExpr, + "min" -> Min, + "max" -> Max) + + val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + + val localAgg = if (aggCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + } + + agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) + } + } else { + statistics.map { case (name, _) => Row(name) } + } + + val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType))) + val rowRdd = sqlContext.sparkContext.parallelize(localAgg) + sqlContext.createDataFrame(rowRdd, schema) + } + /** * Returns the first `n` rows. * @group action diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c30ed694a62f0..afbedd1e5825d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -443,6 +443,51 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) } + test("describe") { + + val describeTestData = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + val describeResult = Seq( + Row("count", 4, 4), + Row("mean", 33.0, 178.0), + Row("stddev", 16.583123951777, 10.0), + Row("min", 16, 164), + Row("max", 60, 192)) + + val emptyDescribeResult = Seq( + Row("count", 0, 0), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) + + def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq + + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + test("apply on query results (SPARK-5462)") { val df = testData.sqlContext.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) From 
855cba8fe59ffe17b51ed00fbbb5d3d7cf17ade9 Mon Sep 17 00:00:00 2001 From: DoingDone9 <799203320@qq.com> Date: Thu, 26 Mar 2015 17:04:19 +0800 Subject: [PATCH 046/171] [SPARK-6546][Build] Using the wrong code that will make spark compile failed!! wrong code : val tmpDir = Files.createTempDir() not Files should Utils Author: DoingDone9 <799203320@qq.com> Closes #5198 from DoingDone9/FilesBug and squashes the following commits: 6e0140d [DoingDone9] Update InsertIntoHiveTableSuite.scala e57d23f [DoingDone9] Update InsertIntoHiveTableSuite.scala 802261c [DoingDone9] Merge pull request #7 from apache/master d00303b [DoingDone9] Merge pull request #6 from apache/master 98b134f [DoingDone9] Merge pull request #5 from apache/master 161cae3 [DoingDone9] Merge pull request #4 from apache/master c87e8b6 [DoingDone9] Merge pull request #3 from apache/master cb1852d [DoingDone9] Merge pull request #2 from apache/master c3f046f [DoingDone9] Merge pull request #1 from apache/master --- .../org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index aa6fb42de7f88..8011952e0d535 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -198,7 +198,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { val testDatawithNull = TestHive.sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") From f88f51bbd461e0a42ad7021147268509b9c3c56e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 26 Mar 2015 18:46:57 +0800 Subject: [PATCH 047/171] [SPARK-6465][SQL] Fix serialization of GenericRowWithSchema using kryo Author: Michael Armbrust Closes #5191 from marmbrus/kryoRowsWithSchema and squashes the following commits: bb83522 [Michael Armbrust] Fix serialization of GenericRowWithSchema using kryo f914f16 [Michael Armbrust] Add no arg constructor to GenericRowWithSchema --- .../spark/sql/catalyst/expressions/rows.scala | 7 +++++-- .../org/apache/spark/sql/types/Metadata.scala | 3 +++ .../org/apache/spark/sql/types/dataTypes.scala | 18 ++++++++++++++++++ .../sql/execution/SparkSqlSerializer.scala | 4 +--- .../scala/org/apache/spark/sql/RowSuite.scala | 12 ++++++++++++ 5 files changed, 39 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8bba26bc4cf7f..a8983df208318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -66,7 +66,7 @@ object EmptyRow extends Row { */ class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. 
*/ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) @@ -172,11 +172,14 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { class GenericRowWithSchema(values: Array[Any], override val schema: StructType) extends GenericRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index e50e9761431f5..6ee24ee0c1913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -41,6 +41,9 @@ import org.apache.spark.annotation.DeveloperApi sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Tests whether this Metadata contains a binding for a key. */ def contains(key: String): Boolean = map.contains(key) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index d973144de3468..952cf5c75688d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -670,6 +670,10 @@ case class PrecisionInfo(precision: Int, scale: Int) */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = Decimal.DecimalIsFractional @@ -819,6 +823,10 @@ object ArrayType { */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") @@ -857,6 +865,9 @@ case class StructField( nullable: Boolean = true, metadata: Metadata = Metadata.empty) { + /** No-arg constructor for kryo. */ + protected def this() = this(null, null) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) @@ -1003,6 +1014,9 @@ object StructType { @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -1121,6 +1135,10 @@ case class MapType( keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. 
*/ + def this() = this(null, null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") builder.append(s"$prefix-- value: ${valueType.typeName} " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index c4534fd5f67e4..967bd76b302d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHa private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { - val kryo = new Kryo() + val kryo = super.newKryo() kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) @@ -57,8 +57,6 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[Decimal]) kryo.setReferences(false) - kryo.setClassLoader(Utils.getSparkClassLoader) - new AllScalaRegistrar().apply(kryo) kryo } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index f5b945f468dad..36465cc2fa11a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.SparkSqlSerializer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends FunSuite { @@ -50,4 +53,13 @@ class RowSuite extends FunSuite { row(0) = null assert(row.isNullAt(0)) } + + test("serialize w/ kryo") { + val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() + val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val instance = serializer.newInstance() + val ser = instance.serialize(row) + val de = instance.deserialize(ser).asInstanceOf[Row] + assert(de === row) + } } From 0c88ce5416d7687bc806a7655e17009ad5823d30 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 26 Mar 2015 12:54:48 +0000 Subject: [PATCH 048/171] [SPARK-6468][Block Manager] Fix the race condition of subDirs in DiskBlockManager There are two race conditions of `subDirs` in `DiskBlockManager`: 1. `getAllFiles` does not use correct locks to read the contents in `subDirs`. Although it's designed for testing, it's still worth to add correct locks to eliminate the race condition. 2. The double-check has a race condition in `getFile(filename: String)`. If a thread finds `subDirs(dirId)(subDirId)` is not null out of the `synchronized` block, it may not be able to see the correct content of the File instance pointed by `subDirs(dirId)(subDirId)` according to the Java memory model (there is no volatile variable here). This PR fixed the above race conditions. 
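For readers unfamiliar with the second hazard, a simplified standalone sketch (not the `DiskBlockManager` code itself) shows why the unsynchronized first read of a lazily filled array slot is unsafe under the Java memory model, and what the fixed access pattern looks like:

```scala
import java.io.File

// Simplified stand-in for subDirs(dirId): an array whose slots are created lazily.
final class LazySlots(root: File, n: Int) {
  private val slots = new Array[File](n)

  // Racy variant: the first read happens outside the lock, so another thread's
  // write to slots(i) is not guaranteed to be visible, or fully constructed,
  // here, because there is no volatile field and no common monitor for that read.
  def getRacy(i: Int): File = {
    val existing = slots(i)
    if (existing != null) {
      existing
    } else {
      slots.synchronized {
        if (slots(i) == null) {
          slots(i) = new File(root, "%02x".format(i))
        }
        slots(i)
      }
    }
  }

  // Safe variant: every read and write of slots(i) happens under the same lock,
  // which is the pattern this patch adopts for subDirs(dirId).
  def getSafe(i: Int): File = slots.synchronized {
    if (slots(i) == null) {
      slots(i) = new File(root, "%02x".format(i))
    }
    slots(i)
  }
}
```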
Author: zsxwing Closes #5136 from zsxwing/SPARK-6468 and squashes the following commits: cbb872b [zsxwing] Fix the race condition of subDirs in DiskBlockManager --- .../spark/storage/DiskBlockManager.scala | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 12cd8ea3bdf1f..2883137872600 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -47,6 +47,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } + // The content of subDirs is immutable but the content of subDirs(i) is mutable. And the content + // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private val shutdownHook = addShutdownHook() @@ -61,20 +63,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon val subDirId = (hash / localDirs.length) % subDirsPerLocalDir // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") - } - subDirs(dirId)(subDirId) = newDir - newDir + val subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.exists() && !newDir.mkdir()) { + throw new IOException(s"Failed to create local dir in $newDir.") } + subDirs(dirId)(subDirId) = newDir + newDir } } @@ -91,7 +90,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories - subDirs.flatten.filter(_ != null).flatMap { dir => + subDirs.flatMap { dir => + dir.synchronized { + // Copy the content of dir because it may be modified in other threads + dir.clone() + } + }.filter(_ != null).flatMap { dir => val files = dir.listFiles() if (files != null) files else Seq.empty } From 1c05027a143d1b0bf3df192984e6cac752b1e926 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Thu, 26 Mar 2015 21:13:38 +0800 Subject: [PATCH 049/171] [SQL][SPARK-6471]: Metastore schema should only be a subset of parquet schema to support dropping of columns using replace columns Currently in the parquet relation 2 implementation, error is thrown in case merged schema is not exactly the same as metastore schema. But to support cases like deletion of column using replace column command, we can relax the restriction so that even if metastore schema is a subset of merged parquet schema, the query will work. 
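As a rough illustration of the relaxed merge (a simplified field type here, not Spark's `StructType`): metastore fields only need to be a subset of the merged Parquet fields, and Parquet-only columns, for example ones dropped via `REPLACE COLUMNS`, are sorted past the end and ignored:

```scala
object SubsetSchemaMergeSketch {
  final case class Field(name: String, dataType: String)

  def mergeMetastoreWithParquet(metastore: Seq[Field], parquet: Seq[Field]): Seq[Field] = {
    require(metastore.size <= parquet.size, "metastore schema must be a subset of the Parquet schema")
    val ordinal = metastore.map(_.name.toLowerCase).zipWithIndex.toMap
    // Fields unknown to the metastore sort after every known field.
    val reordered = parquet.sortBy(f => ordinal.getOrElse(f.name.toLowerCase, metastore.size + 1))
    // zip truncates to the metastore's length; keep Parquet's casing of the
    // name but the metastore's data type, mirroring the rule in the patch.
    metastore.zip(reordered).map { case (m, p) => Field(p.name, m.dataType) }
  }

  def main(args: Array[String]): Unit = {
    val merged = mergeMetastoreWithParquet(
      metastore = Seq(Field("key", "int")),
      parquet   = Seq(Field("key", "int"), Field("dropped_col", "string")))
    assert(merged == Seq(Field("key", "int"))) // the extra Parquet column is ignored
  }
}
```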
Author: Yash Datta Closes #5141 from saucam/replace_col and squashes the following commits: e858d5b [Yash Datta] SPARK-6471: Fix test cases, add a new test case for metastore schema to be subset of parquet schema 5f2f467 [Yash Datta] SPARK-6471: Metastore schema should only be a subset of parquet schema to support dropping of columns using replace columns --- .../apache/spark/sql/parquet/newParquet.scala | 5 +++-- .../spark/sql/parquet/ParquetSchemaSuite.scala | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 410600b0529d3..3516cfe680c61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -758,12 +758,13 @@ private[sql] object ParquetRelation2 extends Logging { |${parquetSchema.prettyJson} """.stripMargin - assert(metastoreSchema.size == parquetSchema.size, schemaConflictMessage) + assert(metastoreSchema.size <= parquetSchema.size, schemaConflictMessage) val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap - val reorderedParquetSchema = parquetSchema.sortBy(f => ordinalMap(f.name.toLowerCase)) + val reorderedParquetSchema = parquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) StructType(metastoreSchema.zip(reorderedParquetSchema).map { // Uses Parquet field names but retains Metastore data types. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 321832cd43211..8462f9bb2d620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -212,8 +212,11 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("UPPERCase", IntegerType, nullable = true)))) } - // Conflicting field count - assert(intercept[Throwable] { + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -221,6 +224,17 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructType(Seq( StructField("lowerCase", BinaryType), StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Conflicting field count + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) // Conflicting field names From 3ddb975faeddeb2674a7e7f7e80cf90dfbd4d6d2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 26 Mar 2015 13:27:05 +0000 Subject: [PATCH 050/171] [MLlib]remove unused import minor thing. Let me know if jira is required. 
Author: Yuhao Yang Closes #5207 from hhbyyh/adjustImport and squashes the following commits: 2240121 [Yuhao Yang] remove unused import --- .../src/main/scala/org/apache/spark/mllib/clustering/LDA.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 5e17c8da61134..9d63a08e211bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} +import breeze.linalg.{DenseVector => BDV, normalize} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental From fe15ea976073edd738c006af1eb8d31617a039fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 26 Mar 2015 15:00:23 +0000 Subject: [PATCH 051/171] SPARK-6480 [CORE] histogram() bucket function is wrong in some simple edge cases Fix fastBucketFunction for histogram() to handle edge conditions more correctly. Add a test, and fix existing one accordingly Author: Sean Owen Closes #5148 from srowen/SPARK-6480 and squashes the following commits: 974a0a0 [Sean Owen] Additional test of huge ranges, and a few more comments (and comment fixes) 23ec01e [Sean Owen] Fix fastBucketFunction for histogram() to handle edge conditions more correctly. Add a test, and fix existing one accordingly --- .../apache/spark/rdd/DoubleRDDFunctions.scala | 20 +++++++--------- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 24 +++++++++++++++---- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 03afc289736bb..71e6e300fec5f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -191,25 +191,23 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced - def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = { // If our input is not a number unless the increment is also NaN then we fail fast - if (e.isNaN()) { - return None - } - val bucketNumber = (e - min)/(increment) - // We do this rather than buckets.lengthCompare(bucketNumber) - // because Array[Double] fails to override it (for now). - if (bucketNumber > count || bucketNumber < 0) { + if (e.isNaN || e < min || e > max) { None } else { - Some(bucketNumber.toInt.min(count - 1)) + // Compute ratio of e's distance along range to total range first, for better precision + val bucketNumber = (((e - min) / (max - min)) * count).toInt + // should be less than count, but will equal count if e == max, in which case + // it's part of the last end-range-inclusive bucket, so return count-1 + Some(math.min(bucketNumber, count - 1)) } } // Decide which bucket function to pass to histogramPartition. 
We decide here - // rather than having a general function so that the decission need only be made + // rather than having a general function so that the decision need only be made // once rather than once per shard val bucketFunction = if (evenBuckets) { - fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _ } else { basicBucketFunction _ } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 4cd0f97368ca3..97079382c716f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -235,6 +235,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithDoubleValuesAtMinMax") { + val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3)) + assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2) + assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2) + } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2)) @@ -248,7 +254,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { } test("WorksWithoutBucketsForLargerDatasets") { - // Verify the case of slighly larger datasets + // Verify the case of slightly larger datasets val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(8) val expectedHistogramResults = @@ -259,17 +265,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } - test("WorksWithoutBucketsWithIrrationalBucketEdges") { - // Verify the case of buckets with irrational edges. See #SPARK-2862. + test("WorksWithoutBucketsWithNonIntegralBucketEdges") { + // Verify the case of buckets with nonintegral edges. See #SPARK-2862. val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(9) + // Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ... 
val expectedHistogramResults = - Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + Array(11, 10, 10, 11, 10, 10, 11, 10, 11) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets(0) === 6.0) assert(histogramBuckets(9) === 99.0) } + test("WorksWithHugeRange") { + val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30)) + val histogramResults = rdd.histogram(1000000)._2 + assert(histogramResults(0) === 1) + assert(histogramResults(1) === 1) + assert(histogramResults.last === 1) + assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0)) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity From c3a52a08248db08eade29b265f02483144a282d6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 26 Mar 2015 10:52:31 -0700 Subject: [PATCH 052/171] SPARK-6532 [BUILD] LDAModel.scala fails scalastyle on Windows Use standard UTF-8 source / report encoding for scalastyle Author: Sean Owen Closes #5211 from srowen/SPARK-6532 and squashes the following commits: 16a33e5 [Sean Owen] Use standard UTF-8 source / report encoding for scalastyle --- pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 23bb16130b504..b3cecd1893a06 100644 --- a/pom.xml +++ b/pom.xml @@ -1452,7 +1452,8 @@ ${basedir}/src/test/scala scalastyle-config.xml scalastyle-output.xml - UTF-8 + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} From 784fcd532784fcfd9bf0a1db71c9f71c469ee716 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 26 Mar 2015 12:26:13 -0700 Subject: [PATCH 053/171] [SPARK-6117] [SQL] Improvements to DataFrame.describe() 1. Slightly modifications to the code to make it more readable. 2. Added Python implementation. 3. Updated the documentation to state that we don't guarantee the output schema for this function and it should only be used for exploratory data analysis. Author: Reynold Xin Closes #5201 from rxin/df-describe and squashes the following commits: 25a7834 [Reynold Xin] Reset run-tests. 6abdfee [Reynold Xin] [SPARK-6117] [SQL] Improvements to DataFrame.describe() --- python/pyspark/sql/dataframe.py | 19 ++++++++ .../org/apache/spark/sql/DataFrame.scala | 46 +++++++++++-------- .../org/apache/spark/sql/DataFrameSuite.scala | 3 +- 3 files changed, 48 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bf7c47b7261a9..d51309f7ef5aa 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -520,6 +520,25 @@ def sort(self, *cols): orderBy = sort + def describe(self, *cols): + """Computes statistics for numeric columns. + + This include count, mean, stddev, min, and max. If no columns are + given, this function computes statistics for all numerical columns. + + >>> df.describe().show() + summary age + count 2 + mean 3.5 + stddev 1.5 + min 2 + max 5 + """ + cols = ListConverter().convert(cols, + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + return DataFrame(jdf, self.sql_ctx) + def head(self, n=None): """ Return the first `n` rows or the first row if n is None. 
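For reference, an approximate Scala counterpart of the Python example above; this is an illustrative sketch only, assuming a `SQLContext` named `sqlContext` is in scope and using made-up data. As the updated Scala docs below emphasize, `describe()` is meant for exploratory analysis, so the `agg`-based alternative is shown for programmatic use.

```scala
import org.apache.spark.sql.functions._

// Illustrative data; any DataFrame with numeric columns works the same way.
val people = sqlContext.createDataFrame(Seq(
  ("Bob", 16, 176),
  ("Alice", 32, 164))).toDF("name", "age", "height")

// Exploratory summary: count, mean, stddev, min, max of the named columns.
people.describe("age", "height").show()

// For stable, programmatic statistics, prefer explicit aggregates instead.
people.agg(count("age"), avg("age"), min("height"), max("height")).show()
```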
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index db561825e676b..4c80359cf07af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils @@ -752,15 +752,17 @@ class DataFrame private[sql]( } /** - * Compute numerical statistics for given columns of this [[DataFrame]]: - * count, mean (avg), stddev (standard deviation), min, max. - * Each row of the resulting [[DataFrame]] contains column with statistic name - * and columns with statistic results for each given column. - * If no columns are given then computes for all numerical columns. + * Computes statistics for numeric columns, including count, mean, stddev, min, and max. + * If no columns are given, this function computes statistics for all numerical columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to + * programmatically compute summary statistics, use the `agg` function instead. * * {{{ - * df.describe("age", "height") + * df.describe("age", "height").show() * + * // output: * // summary age height * // count 10.0 10.0 * // mean 53.3 178.05 @@ -768,13 +770,17 @@ class DataFrame private[sql]( * // min 18.0 163.0 * // max 92.0 192.0 * }}} + * + * @group action */ @scala.annotation.varargs def describe(cols: String*): DataFrame = { - def stddevExpr(expr: Expression) = + // TODO: Add stddev as an expression, and remove it from here. + def stddevExpr(expr: Expression): Expression = Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + // The list of summary statistics to compute, in the form of expressions. 
val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, @@ -782,24 +788,28 @@ class DataFrame private[sql]( "min" -> Min, "max" -> Max) - val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList - val localAgg = if (aggCols.nonEmpty) { + val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => - aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) } - agg(aggExprs.head, aggExprs.tail: _*).head().toSeq - .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) + val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + row.grouped(outputCols.size).toSeq.zip(statistics).map { + case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) } } else { + // If there are no output columns, just output a single column that contains the stats. statistics.map { case (name, _) => Row(name) } } - val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType))) - val rowRdd = sqlContext.sparkContext.parallelize(localAgg) - sqlContext.createDataFrame(rowRdd, schema) + // The first column is string type, and the rest are double type. + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + LocalRelation(schema, ret) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index afbedd1e5825d..fbc4065a9666c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -444,7 +444,6 @@ class DataFrameSuite extends QueryTest { } test("describe") { - val describeTestData = Seq( ("Bob", 16, 176), ("Alice", 32, 164), @@ -465,7 +464,7 @@ class DataFrameSuite extends QueryTest { Row("min", null, null), Row("max", null, null)) - def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) val describeTwoCols = describeTestData.describe("age", "height") assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) From 71a0d40ebd37c80d8020e184366778b57c762285 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 26 Mar 2015 13:11:37 -0700 Subject: [PATCH 054/171] [SPARK-6554] [SQL] Don't push down predicates which reference partition column(s) There are two cases for the new Parquet data source: 1. Partition columns exist in the Parquet data files We don't need to push-down these predicates since partition pruning already handles them. 1. Partition columns don't exist in the Parquet data files We can't push-down these predicates since they are considered as invalid columns by Parquet. 
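The intent can be summarized with a small, simplified sketch (the helper name is invented; the real change is in the diff below): a candidate filter is pushed down only when none of the attributes it references is a partition column, since partition columns are handled by partition pruning and are unknown to the Parquet files themselves.

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// Keep only predicates that mention no partition column; the rest are either
// already covered by partition pruning or reference columns Parquet doesn't know.
def pushablePredicates(
    predicates: Seq[Expression],
    partitionColNames: Set[String]): Seq[Expression] = {
  predicates.filter { pred =>
    pred.references.map(_.name).toSet.intersect(partitionColNames).isEmpty
  }
}
```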
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/5210) Author: Cheng Lian Closes #5210 from liancheng/spark-6554 and squashes the following commits: 4f7ec03 [Cheng Lian] Adds comments e134ced [Cheng Lian] Don't push down predicates which reference partition column(s) --- .../apache/spark/sql/parquet/newParquet.scala | 17 ++++++++++++----- .../spark/sql/parquet/ParquetFilterSuite.scala | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 3516cfe680c61..0d68810ec6043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -435,11 +435,18 @@ private[sql] case class ParquetRelation2( // Push down filters when possible. Notice that not all filters can be converted to Parquet // filter predicate. Here we try to convert each individual predicate and only collect those // convertible ones. - predicates - .flatMap(ParquetFilters.createFilter) - .reduceOption(FilterApi.and) - .filter(_ => sqlContext.conf.parquetFilterPushDown) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + if (sqlContext.conf.parquetFilterPushDown) { + predicates + // Don't push down predicates which reference partition columns + .filter { pred => + val partitionColNames = partitionColumns.map(_.name).toSet + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } + .flatMap(ParquetFilters.createFilter) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + } if (isPartitioned) { logInfo { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4d32e84fc1115..6a2c2a7c4080a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -321,6 +321,23 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } + + test("SPARK-6554: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.parquetFile(path).filter("part = 1"), + (1 to 3).map(i => Row(i, i.toString, 1))) + } + } + } } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { From aad00322765d6041e817a6bd3fcff2187d212057 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 26 Mar 2015 14:51:46 -0700 Subject: [PATCH 055/171] [DOCS][SQL] Fix JDBC example Author: Michael Armbrust Closes #5192 from marmbrus/fixJDBCDocs and squashes the following commits: b48a33d [Michael Armbrust] [DOCS][SQL] Fix JDBC example --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c99a0b03442c4..4441d6a000a02 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1406,7 +1406,7 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load("jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") {% endhighlight %} From 39fb57968352549f2276ac4fcd2b92988ed6fe42 Mon Sep 17 00:00:00 2001 From: Brennon York Date: Thu, 26 Mar 2015 19:08:09 -0700 Subject: [PATCH 056/171] [SPARK-6510][GraphX]: Add Graph#minus method to act as Set#difference Adds a `Graph#minus` method which will return only unique `VertexId`'s from the calling `VertexRDD`. To demonstrate a basic example with pseudocode: ``` Set((0L,0),(1L,1)).minus(Set((1L,1),(2L,2))) > Set((0L,0)) ``` Author: Brennon York Closes #5175 from brennonyork/SPARK-6510 and squashes the following commits: 248d5c8 [Brennon York] added minus(VertexRDD[VD]) method to avoid createUsingIndex and updated the mask operations to simplify with andNot call 3fb7cce [Brennon York] updated graphx doc to reflect the addition of minus method 6575d92 [Brennon York] updated mima exclude aaa030b [Brennon York] completed graph#minus functionality 7227c0f [Brennon York] beginning work on minus functionality --- docs/graphx-programming-guide.md | 2 ++ .../org/apache/spark/graphx/VertexRDD.scala | 16 +++++++++ .../graphx/impl/VertexPartitionBaseOps.scala | 15 +++++++++ .../spark/graphx/impl/VertexRDDImpl.scala | 25 ++++++++++++++ .../apache/spark/graphx/VertexRDDSuite.scala | 33 +++++++++++++++++-- project/MimaExcludes.scala | 3 ++ 6 files changed, 92 insertions(+), 2 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index c601d793a2e9a..3f10cb2dc3d2a 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -899,6 +899,8 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] { // Transform the values without changing the ids (preserves the internal index) def mapValues[VD2](map: VD => VD2): VertexRDD[VD2] def mapValues[VD2](map: (VertexId, VD) => VD2): VertexRDD[VD2] + // Show only vertices unique to this set based on their VertexId's + def minus(other: RDD[(VertexId, VD)]) // Remove vertices from this set that appear in the other set def diff(other: VertexRDD[VD]): VertexRDD[VD] // Join operators that take advantage of the internal indexing to accelerate joins (substantially) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index ad4bfe077293a..a9f04b559c3d1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -121,6 +121,22 @@ abstract class VertexRDD[VD]( */ def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other an RDD to run the set operation against + */ + def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] + + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. 
+ * + * @param other a VertexRDD to run the set operation against + */ + def minus(other: VertexRDD[VD]): VertexRDD[VD] + /** * For each vertex present in both `this` and `other`, `diff` returns only those vertices with * differing values; for values that are different, keeps the values from `other`. This is diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index 4fd2548b7faf6..b90f9fa327052 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -88,6 +88,21 @@ private[graphx] abstract class VertexPartitionBaseOps this.withMask(newMask) } + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Self[VD]): Self[VD] = { + if (self.index != other.index) { + logWarning("Minus operations on two VertexPartitions with different indexes is slow.") + minus(createUsingIndex(other.iterator)) + } else { + self.withMask(self.mask.andNot(other.mask)) + } + } + + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Iterator[(VertexId, VD)]): Self[VD] = { + minus(createUsingIndex(other)) + } + /** * Hides vertices that are the same between this and other. For vertices that are different, keeps * the values from `other`. The indices of `this` and `other` must be the same. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 125692ddaad83..349c8545bf201 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -103,6 +103,31 @@ class VertexRDDImpl[VD] private[graphx] ( override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = this.mapVertexPartitions(_.map(f)) + override def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { + minus(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) + } + + override def minus (other: VertexRDD[VD]): VertexRDD[VD] = { + other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true) { + (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.minus(otherPart)) + }) + case _ => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.minus(msgs)) + } + ) + } + } + override def diff(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { diff(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index 4f7a442ab503d..c9443d11c76cf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -47,6 +47,35 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("minus") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 
1))).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + } + } + + test("minus with RDD[(VertexId, VD)]") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB: RDD[(VertexId, Int)] = + sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + } + } + + test("minus with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 5).map(i => (i.toLong, 0))) + val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 50).toSet) + } + } + test("diff") { withSpark { sc => val n = 100 @@ -71,7 +100,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } - test("diff vertices with the non-equal number of partitions") { + test("diff vertices with non-equal number of partitions") { withSpark { sc => val vertexA = VertexRDD(sc.parallelize(0 until 24, 3).map(i => (i.toLong, 0))) val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1))) @@ -96,7 +125,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } - test("leftJoin vertices with the non-equal number of partitions") { + test("leftJoin vertices with non-equal number of partitions") { withSpark { sc => val vertexA = VertexRDD(sc.parallelize(0 until 100, 2).map(i => (i.toLong, 1))) val vertexB = VertexRDD( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 56f5dbe53fad4..b9f40046e15a2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -51,6 +51,9 @@ object MimaExcludes { "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") ) case v if v.startsWith("1.3") => From 49d2ec63eccec8a3a78b15b583c36f84310fc6f0 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 26 Mar 2015 22:48:42 -0700 Subject: [PATCH 057/171] [SPARK-6405] Limiting the maximum Kryo buffer size to be 2GB. Kryo buffers are backed by byte arrays, but primitive arrays can only be up to 2GB in size. It is misleading to allow users to set buffers past this size. Author: mcheah Closes #5218 from mccheah/feature/limit-kryo-buffer and squashes the following commits: 1d6d1be [mcheah] Fixing numeric typo e2e30ce [mcheah] Removing explicit int and double type to match style 09fd80b [mcheah] Should be >= not >. Slightly more consistent error message. 60634f9 [mcheah] [SPARK-6405] Limiting the maximum Kryo buffer size to be 2GB. 
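The 2GB ceiling exists because Kryo buffers are backed by `Array[Byte]`, and JVM arrays are indexed by `Int`, so a single buffer cannot exceed roughly `Int.MaxValue` bytes. A minimal sketch of the validation pattern, with an invented helper name and defaults (the actual checks are in the diff below):

```scala
import org.apache.spark.SparkConf

// Buffer sizes are configured in megabytes; reject anything that would overflow
// the Int-indexed byte array backing the Kryo buffer (>= 2048 mb).
def validatedBufferSizeBytes(conf: SparkConf, key: String, defaultMb: Int): Int = {
  val sizeMb = conf.getInt(key, defaultMb)
  require(sizeMb < 2048, s"$key must be less than 2048 mb, got: $sizeMb mb.")
  sizeMb * 1024 * 1024
}
```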
--- .../apache/spark/serializer/KryoSerializer.scala | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index f83bcaa5cc09e..579fb6624e692 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,10 +49,20 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = - (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) + if (bufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + + s"2048 mb, got: + $bufferSizeMb mb.") + } + private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + + val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + if (maxBufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + s"2048 mb, got: + $maxBufferSizeMb mb.") + } + private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 - private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val userRegistrator = conf.getOption("spark.kryo.registrator") From f43a61031fd7d9d4fab3d8ac584e7b4c7c5e1035 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 27 Mar 2015 00:15:02 -0700 Subject: [PATCH 058/171] [SPARK-6341][mllib] Upgrade breeze from 0.11.1 to 0.11.2 There are any bugs of breeze's SparseVector at 0.11.1. You know, Spark 1.3 depends on breeze 0.11.1. So I think we should upgrade it to 0.11.2. https://issues.apache.org/jira/browse/SPARK-6341 And thanks you for your great cooperation, David Hall(dlwh) Author: Yu ISHIKAWA Closes #5222 from yu-iskw/upgrade-breeze and squashes the following commits: ad8a688 [Yu ISHIKAWA] Upgrade breeze from 0.11.1 to 0.11.2 because of a bug of SparseVector. Thanks you for your great cooperation, David Hall(@dlwh) --- mllib/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index 4c183543e3fa8..5dfab36c76907 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -64,7 +64,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.11.1 + 0.11.2 From da546b7ba03d84d7f6af97fe04471b12f5b3392f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 27 Mar 2015 12:31:06 +0000 Subject: [PATCH 059/171] [SPARK-6556][Core] Fix wrong parsing logic of executorTimeoutMs and checkTimeoutIntervalMs in HeartbeatReceiver The current reading logic of `executorTimeoutMs` is: ```Scala private val executorTimeoutMs = sc.conf.getLong("spark.network.timeout", sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120)) * 1000 ``` So if `spark.storage.blockManagerSlaveTimeoutMs` is 10000 and `spark.network.timeout` is not set, executorTimeoutMs will be 10000 * 1000. But the correct value should have been 10000. `checkTimeoutIntervalMs` has the same issue. This PR fixes them. 
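A worked example of the bug: with `spark.storage.blockManagerSlaveTimeoutMs` set to 10000 and `spark.network.timeout` unset, the old expression evaluated to 10000 * 1000 = 10,000,000 ms instead of the intended 10,000 ms, because the fallback value is already in milliseconds. A condensed restatement of the corrected logic (illustrative helper; the actual change is in the diff that follows):

```scala
import org.apache.spark.SparkConf

// Only "spark.network.timeout" (seconds) needs the * 1000 conversion; the
// fallback "spark.storage.blockManagerSlaveTimeoutMs" is already in milliseconds.
def executorTimeoutMs(conf: SparkConf): Long = {
  conf.getOption("spark.network.timeout").map(_.toLong * 1000)
    .getOrElse(conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120000))
}
```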
Author: zsxwing Closes #5209 from zsxwing/SPARK-6556 and squashes the following commits: 6a0a411 [zsxwing] Fix docs c7d5422 [zsxwing] Add comments for executorTimeoutMs and checkTimeoutIntervalMs ccd5147 [zsxwing] Fix wrong parsing logic of executorTimeoutMs and checkTimeoutIntervalMs in HeartbeatReceiver --- .../org/apache/spark/HeartbeatReceiver.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 715f292f03469..548dcb93c3358 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -49,12 +49,17 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new mutable.HashMap[String, Long] - - private val executorTimeoutMs = sc.conf.getLong("spark.network.timeout", - sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120)) * 1000 - - private val checkTimeoutIntervalMs = sc.conf.getLong("spark.network.timeoutInterval", - sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60)) * 1000 + + // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses + // "milliseconds" + private val executorTimeoutMs = sc.conf.getOption("spark.network.timeout").map(_.toLong * 1000). + getOrElse(sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120000)) + + // "spark.network.timeoutInterval" uses "seconds", while + // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" + private val checkTimeoutIntervalMs = + sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000). + getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)) private var timeoutCheckingTask: Cancellable = null From aa2b9917489f9bbb02c8acea5ff43335042e2705 Mon Sep 17 00:00:00 2001 From: Dean Chen Date: Fri, 27 Mar 2015 14:32:51 +0000 Subject: [PATCH 060/171] [SPARK-6544][build] Increment Avro version from 1.7.6 to 1.7.7 Fixes bug causing Kryo serialization to fail with Avro files in between stages. https://issues.apache.org/jira/browse/AVRO-1476?focusedCommentId=13999249&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-13999249 Author: Dean Chen Closes #5193 from deanchen/SPARK-6544 and squashes the following commits: 813d4c5 [Dean Chen] [SPARK-6544][build] Increment Avro version from 1.7.6 to 1.7.7 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index b3cecd1893a06..3eb3da2cd8af3 100644 --- a/pom.xml +++ b/pom.xml @@ -141,7 +141,7 @@ 2.4.0 2.0.8 3.1.0 - 1.7.6 + 1.7.7 0.7.1 1.8.3 From 5d9c37c23d1edd91e6c5561780006b762cde5f66 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 27 Mar 2015 11:40:00 -0700 Subject: [PATCH 061/171] [SPARK-6550][SQL] Use analyzed plan in DataFrame This is based on bug and test case proposed by viirya. See #5203 for a excellent description of the problem. TLDR; The problem occurs because the function `groupBy(String)` calls `resolve`, which returns an `AttributeReference`. However, this `AttributeReference` is based on an analyzed plan which is thrown away. At execution time, we once again analyze the plan. 
However, in the case of self-joins, each call to analyze will produce a new tree for the left side of the join, rendering the previously returned `AttributeReference` invalid. As a fix, I propose we keep the analyzed plan instead of the unresolved plan inside of a `DataFrame`. Author: Michael Armbrust Closes #5217 from marmbrus/preanalyzer and squashes the following commits: 1f98e2d [Michael Armbrust] revert change dd4dec1 [Michael Armbrust] Use the analyzed plan in DataFrame 089c52e [Michael Armbrust] WIP --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4c80359cf07af..423ef3912bc89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -146,7 +146,7 @@ class DataFrame private[sql]( _: WriteToFile => LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => - queryExecution.logical + queryExecution.analyzed } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index fbc4065a9666c..5f03805d70416 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -113,6 +113,10 @@ class DataFrameSuite extends QueryTest { checkAnswer( df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } test("explode") { From 887e1b72dfa5965f8ab1aad212fb33bb365b0e1b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 27 Mar 2015 11:42:26 -0700 Subject: [PATCH 062/171] [SPARK-6574] [PySpark] fix sql example Fix the import in sql example. Author: Davies Liu Closes #5230 from davies/fix_sql_example and squashes the following commits: 7ecc5f4 [Davies Liu] fix sql example --- examples/src/main/python/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 47202fde7510b..d89361f324917 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -19,7 +19,7 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.sql import Row, StructField, StructType, StringType, IntegerType +from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType if __name__ == "__main__": From d5497ab1343e4d1b2a1c336f2e3520d74c6674a1 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 27 Mar 2015 13:29:10 -0700 Subject: [PATCH 063/171] [SPARK-6526][ML] Add Normalizer transformer in ML package See [SPARK-6526](https://issues.apache.org/jira/browse/SPARK-6526). mengxr Should we add test suite for this transformer? There is no test suite for all feature transformers in ML package now. 
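A hedged usage sketch of the new transformer (not part of the patch): `dataset` is assumed to be a DataFrame holding a `Vector` column named "features", and the output column name is made up.

```scala
import org.apache.spark.ml.feature.Normalizer

// Normalize each feature vector to unit L^1 norm; p defaults to 2 (L^2).
val normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")
  .setP(1.0)

val normalized = normalizer.transform(dataset)
```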
Author: Xusen Yin Closes #5181 from yinxusen/SPARK-6526 and squashes the following commits: 6faa7bf [Xusen Yin] fix style 8a462da [Xusen Yin] remove duplications ab35ab0 [Xusen Yin] add test suite bc8cd0f [Xusen Yin] fix comment 79774c9 [Xusen Yin] add Normalizer transformer in ML package --- .../apache/spark/ml/feature/Normalizer.scala | 53 +++++++++ .../spark/ml/feature/NormalizerSuite.scala | 109 ++++++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala new file mode 100644 index 0000000000000..05f91dc9105fe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{DoubleParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.sql.types.DataType + +/** + * :: AlphaComponent :: + * Normalize a vector to have unit norm using the given p-norm. + */ +@AlphaComponent +class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { + + /** + * Normalization in L^p^ space, p = 2 by default. + * @group param + */ + val p = new DoubleParam(this, "p", "the p norm value", Some(2)) + + /** @group getParam */ + def getP: Double = get(p) + + /** @group setParam */ + def setP(value: Double): this.type = set(p, value) + + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { + val normalizer = new feature.Normalizer(paramMap(p)) + normalizer.transform + } + + override protected def outputDataType: DataType = new VectorUDT() +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala new file mode 100644 index 0000000000000..a18c335952b96 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +private case class DataSet(features: Vector) + +class NormalizerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var data: Array[Vector] = _ + @transient var dataFrame: DataFrame = _ + @transient var normalizer: Normalizer = _ + @transient var l1Normalized: Array[Vector] = _ + @transient var l2Normalized: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + l1Normalized = Array( + Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.12765957, -0.23404255, -0.63829787), + Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))), + Vectors.dense(0.625, 0.07894737, 0.29605263), + Vectors.sparse(3, Seq()) + ) + l2Normalized = Array( + Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.184549876, -0.3383414, -0.922749378), + Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))), + Vectors.dense(0.897906166, 0.113419726, 0.42532397), + Vectors.sparse(3, Seq()) + ) + + val sqlContext = new SQLContext(sc) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet)) + normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normalized_features") + } + + def collectResult(result: DataFrame): Array[Vector] = { + result.select("normalized_features").collect().map { + case Row(features: Vector) => features + } + } + + def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + } + + def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { (vector1, vector2) => + vector1 ~== vector2 absTol 1E-5 + }, "The vector value is not correct after normalization.") + } + + test("Normalization with default parameter") { + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l2Normalized) + } + + test("Normalization with setter") { + normalizer.setP(1) + + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l1Normalized) + } +} From 3af7334304341fba091aa39ce2efbdfd167c697b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 Mar 2015 14:56:57 -0700 Subject: [PATCH 064/171] [SPARK-6564][SQL] 
SQLContext.emptyDataFrame should contain 0 row, not 1 row Author: Reynold Xin Closes #5226 from rxin/empty-df and squashes the following commits: 1306d88 [Reynold Xin] Proper fix. e135bb9 [Reynold Xin] [SPARK-6564][SQL] SQLContext.emptyDataFrame should contain 0 rows, not 1 row. --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 5 ++++- .../sql/catalyst/optimizer/ExpressionOptimizationSuite.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 4 ++-- .../org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 5 +++++ .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 4 ++-- 8 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ea7d44a3723d1..b176f7e729a42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -139,7 +139,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { sortType.? ~ (LIMIT ~> expression).? ^^ { case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) + val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g .map(Aggregate(_, assignAliases(p), withFilter)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 4d9e41a2b5d85..190209238a4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -287,7 +287,10 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -case object NoRelation extends LeafNode { +/** + * A relation with one row. This is used in "SELECT ..." without a from clause. 
+ */ +case object OneRowRelation extends LeafNode { override def output: Seq[Attribute] = Nil /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index ae99a3f9ba287..2f3704be59a9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -29,7 +29,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, NoRelation) + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) val optimizedPlan = DefaultOptimizer(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 48884040bfce7..129d091ca03e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{NoRelation, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ /** @@ -55,6 +55,6 @@ class PlanTest extends FunSuite { /** Fails the test if the two expressions do not match */ protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, NoRelation), Filter(e2, NoRelation)) + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e59cf9b9e037b..b8100782ec937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} @@ -177,7 +177,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental @transient - lazy val emptyDataFrame = DataFrame(this, NoRelation) + lazy val emptyDataFrame: DataFrame = createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) /** * A collection of methods for registering user-defined functions (UDF). 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2b581152e5f77..f754fa770d1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -296,7 +296,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil - case logical.NoRelation => + case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5f03805d70416..6761d996fd975 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -84,6 +84,11 @@ class DataFrameSuite extends QueryTest { testData.collect().toSeq) } + test("empty data frame") { + assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(TestSQLContext.emptyDataFrame.count() === 0) + } + test("head and take") { assert(testData.take(2) === testData.collect().take(2)) assert(testData.head(2) === testData.collect().take(2)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c45c4ad70fae9..cd8e7c09eea5b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -479,7 +479,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(NoRelation) + ExplainCommand(OneRowRelation) case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.getText => val Some(crtTbl) :: _ :: extended :: Nil = @@ -622,7 +622,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val relations = fromClause match { case Some(f) => nodeToRelation(f) - case None => NoRelation + case None => OneRowRelation } val withWhere = whereClause.map { whereNode => From 5909f0973de15f685836c2828e6d4c38f57d2c19 Mon Sep 17 00:00:00 2001 From: Adam Budde Date: Sat, 28 Mar 2015 09:14:09 +0800 Subject: [PATCH 065/171] [SPARK-6538][SQL] Add missing nullable Metastore fields when merging a Parquet schema Opening to replace #5188. When Spark SQL infers a schema for a DataFrame, it will take the union of all field types present in the structured source data (e.g. an RDD of JSON data). When the source data for a row doesn't define a particular field on the DataFrame's schema, a null value will simply be assumed for this field. This workflow makes it very easy to construct tables and query over a set of structured data with a nonuniform schema. However, this behavior is not consistent in some cases when dealing with Parquet files and an external table managed by an external Hive metastore. 
In our particular usecase, we use Spark Streaming to parse and transform our input data and then apply a window function to save an arbitrary-sized batch of data as a Parquet file, which itself will be added as a partition to an external Hive table via an *"ALTER TABLE... ADD PARTITION..."* statement. Since our input data is nonuniform, it is expected that not every partition batch will contain every field present in the table's schema obtained from the Hive metastore. As such, we expect that the schema of some of our Parquet files may not contain the same set fields present in the full metastore schema. In such cases, it seems natural that Spark SQL would simply assume null values for any missing fields in the partition's Parquet file, assuming these fields are specified as nullable by the metastore schema. This is not the case in the current implementation of ParquetRelation2. The **mergeMetastoreParquetSchema()** method used to reconcile differences between a Parquet file's schema and a schema retrieved from the Hive metastore will raise an exception if the Parquet file doesn't match the same set of fields specified by the metastore. This pull requests alters the behavior of **mergeMetastoreParquetSchema()** by having it first add any nullable fields from the metastore schema to the Parquet file schema if they aren't already present there. Author: Adam Budde Closes #5214 from budde/nullable-fields and squashes the following commits: a52d378 [Adam Budde] Refactor ParquetSchemaSuite.scala for cases now permitted by SPARK-6471 and SPARK-6538 9041bfa [Adam Budde] Add missing nullable Metastore fields when merging a Parquet schema --- .../apache/spark/sql/parquet/newParquet.scala | 32 ++++++++++++++- .../sql/parquet/ParquetSchemaSuite.scala | 40 +++++++++++++++++-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 0d68810ec6043..53f765ee26a13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -765,12 +765,14 @@ private[sql] object ParquetRelation2 extends Logging { |${parquetSchema.prettyJson} """.stripMargin - assert(metastoreSchema.size <= parquetSchema.size, schemaConflictMessage) + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap - val reorderedParquetSchema = parquetSchema.sortBy(f => + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) StructType(metastoreSchema.zip(reorderedParquetSchema).map { @@ -782,6 +784,32 @@ private[sql] object ParquetRelation2 extends Logging { }) } + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. 
+ * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. + */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } + + // TODO Data source implementations shouldn't touch Catalyst types (`Literal`). // However, we are already using Catalyst expressions for partition pruning and predicate // push-down here... diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 8462f9bb2d620..61f1cf347ab0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -226,22 +226,54 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("UPPERCase", IntegerType, nullable = true)))) } - // Conflicting field count + // Metastore schema contains additional non-nullable fields. assert(intercept[Throwable] { ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), - StructField("lowerCase", BinaryType))), + StructField("lowerCase", BinaryType, nullable = false))), StructType(Seq( StructField("UPPERCase", IntegerType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) - // Conflicting field names + // Conflicting non-nullable field names intercept[Throwable] { ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq(StructField("lower", StringType))), + StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. 
+ assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } } From 99631438c0ec777d6a77974b148dbbd3e890260e Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Sat, 28 Mar 2015 12:32:35 +0000 Subject: [PATCH 066/171] [SPARK-6552][Deploy][Doc]expose start-slave.sh to user and update outdated doc https://issues.apache.org/jira/browse/SPARK-6552 /cc srowen Author: WangTaoTheTonic Closes #5205 from WangTaoTheTonic/SPARK-6552 and squashes the following commits: b02263c [WangTaoTheTonic] use less than rather than less equal f0fa408 [WangTaoTheTonic] expose start-slave.sh --- docs/spark-standalone.md | 3 ++- sbin/start-slave.sh | 10 ++++++++-- sbin/start-slaves.sh | 2 ++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 74d8653a8b845..0eed9adacf123 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./bin/spark-class org.apache.spark.deploy.worker.Worker spark://IP:PORT + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -81,6 +81,7 @@ Once you've set up this file, you can launch or stop your cluster with the follo - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. +- `sbin/start-slave.sh` - Starts a slave instance on the machine the script is executed on. - `sbin/start-all.sh` - Starts both a master and a number of slaves as described above. - `sbin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. - `sbin/stop-slaves.sh` - Stops all slave instances on the machines specified in the `conf/slaves` file. diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 2fc35309f4ca5..c0155384f7395 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -17,8 +17,14 @@ # limitations under the License. # -# Usage: start-slave.sh -# where is like "spark://localhost:7077" +# Starts a slave on the machine this script is executed on. + +usage="Usage: start-slave.sh where is like "spark://localhost:7077" + +if [ $# -lt 2 ]; then + echo $usage + exit 1 +fi sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 76316a3067c93..4356c03657109 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -17,6 +17,8 @@ # limitations under the License. # +# Starts a slave instance on each machine specified in the conf/slaves file. 
+ sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" From f75f633b21faaf911f04aeff847f25749b1ecd89 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 28 Mar 2015 15:08:05 -0700 Subject: [PATCH 067/171] [SPARK-6571][MLLIB] use wrapper in MatrixFactorizationModel.load This fixes `predictAll` after load. jkbradley Author: Xiangrui Meng Closes #5243 from mengxr/SPARK-6571 and squashes the following commits: 82dcaa7 [Xiangrui Meng] use wrapper in MatrixFactorizationModel.load --- .../MatrixFactorizationModelWrapper.scala | 40 +++++++++++++++++++ .../mllib/api/python/PythonMLLibAPI.scala | 18 --------- python/pyspark/mllib/recommendation.py | 8 ++++ 3 files changed, 48 insertions(+), 18 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala new file mode 100644 index 0000000000000..ecd3b16598438 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of MatrixFactorizationModel to provide helper method for Python. + */ +private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + } + + def getProductFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e39156734794c..22fa684fd2895 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -58,7 +58,6 @@ import org.apache.spark.util.Utils */ private[python] class PythonMLLibAPI extends Serializable { - /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. 
* @param jsc Java SparkContext @@ -346,24 +345,7 @@ private[python] class PythonMLLibAPI extends Serializable { model.predictSoft(data) } - /** - * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python - */ - private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) - extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { - def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = - predict(SerDe.asTupleRDD(userAndProducts.rdd)) - - def getUserFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) - } - - def getProductFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) - } - - } /** * Java stub for Python mllib ALS.train(). This stub returns a handle diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 1a4527b12cef2..b094e50856f70 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -90,6 +90,8 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> sameModel = MatrixFactorizationModel.load(sc, path) >>> sameModel.predict(2,2) 0.43... + >>> sameModel.predictAll(testset).collect() + [Rating(... >>> try: ... os.removedirs(path) ... except OSError: @@ -111,6 +113,12 @@ def userFeatures(self): def productFeatures(self): return self.call("getProductFeatures") + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) + return MatrixFactorizationModel(wrapper) + class ALS(object): From 5eef00d0c6c7cc5448aca7b1c2a2e289a4c43eb0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 28 Mar 2015 23:59:27 -0700 Subject: [PATCH 068/171] [DOC] Improvements to Python docs. Author: Reynold Xin Closes #5238 from rxin/pyspark-docs and squashes the following commits: c285951 [Reynold Xin] Reset deprecation warning. 8c1031e [Reynold Xin] inferSchema dd91b1a [Reynold Xin] [DOC] Improvements to Python docs. --- python/docs/index.rst | 8 ++++++++ python/pyspark/sql/__init__.py | 14 ++++++++------ python/pyspark/sql/dataframe.py | 9 +-------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/python/docs/index.rst b/python/docs/index.rst index d150de9d5c502..f7eede9c3c82a 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -29,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.sql.SQLContext` + + Main entry point for DataFrame and SQL functionality. + + :class:`pyspark.sql.DataFrame` + + A distributed collection of data grouped into named columns. + Indices and tables ================== diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index b9ffd6945ea7e..54a01631d8899 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -19,17 +19,19 @@ public classes of Spark SQL: - L{SQLContext} - Main entry point for SQL functionality. + Main entry point for :class:`DataFrame` and SQL functionality. - L{DataFrame} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, DataFrames also support SQL. + A distributed collection of data grouped into named columns. - L{GroupedData} + Aggregation methods, returned by :func:`DataFrame.groupBy`. - L{Column} - Column is a DataFrame with a single column. 
+ A column expression in a :class:`DataFrame`. - L{Row} - A Row of data returned by a Spark SQL query. + A row of data in a :class:`DataFrame`. - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. + Main entry point for accessing data stored in Apache Hive. + - L{functions} + List of built-in functions available for :class:`DataFrame`. """ from pyspark.sql.context import SQLContext, HiveContext diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d51309f7ef5aa..23c0e63e77812 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -50,13 +50,6 @@ class DataFrame(object): ageCol = people.age - Note that the :class:`Column` type can also be manipulated - through its various functions:: - - # The following creates a new column that increases everybody's age by 10. - people.age + 10 - - A more concrete example:: # To create DataFrame using SQLContext @@ -77,7 +70,7 @@ def __init__(self, jdf, sql_ctx): @property def rdd(self): """ - Return the content of the :class:`DataFrame` as an :class:`RDD` + Return the content of the :class:`DataFrame` as an :class:`pyspark.RDD` of :class:`Row` s. """ if not hasattr(self, '_lazy_rdd'): From 55153f5c14fad10607b44fbb8eebd9636a6bc2e1 Mon Sep 17 00:00:00 2001 From: Brennon York Date: Sun, 29 Mar 2015 12:37:53 +0100 Subject: [PATCH 069/171] [SPARK-4123][Project Infra]: Show new dependencies added in pull requests Starting work on this, but need to find a way to ensure that, after doing a checkout from `apache/master`, we can successfully return to the current checkout. I believe that `git rev-parse HEAD` will get me what I want, but pushing this PR up to test what the Jenkins boxes are seeing. Author: Brennon York Closes #5093 from brennonyork/SPARK-4123 and squashes the following commits: 42e243e [Brennon York] moved starting test output to before pr tests, fixed indentation, changed mvn call to build/mvn dadd941 [Brennon York] reverted assembly pom, put the regular test suite back in play 7aa1dee [Brennon York] set new dendencies into a block, removed the bash debugging flag 0074566 [Brennon York] fixed minor echo issue with quotes e229802 [Brennon York] updated to print the new dependency found 27bb9b5 [Brennon York] changed the assembly pom to test whether the pr test will pick up new deps 5375ad8 [Brennon York] git output to dev null 9bce980 [Brennon York] ensure both gate files exist 8f3c4b4 [Brennon York] updated to reflect the correct pushed in HEAD variable 2bc7b27 [Brennon York] added a pom gate check a18db71 [Brennon York] full test of new deps script ea170de [Brennon York] dont let mvn execute tests f70d8cd [Brennon York] testing mvn with package 62ffd65 [Brennon York] updated dependency output message and changed compile to package given the jenkins failure output 04747e4 [Brennon York] adding simple mvn statement to see if command executes and prints compile output 87f9bea [Brennon York] added -x flag with bash to get insight into what is executing and what isnt 9e87208 [Brennon York] added set blocks to catch any non-zero exit codes and updated output 6b3042b [Brennon York] removed excess git checkout print statements 4077d46 [Brennon York] Merge remote-tracking branch 'upstream/master' into SPARK-4123 2bb5527 [Brennon York] added echo statement so jenkins logs which pr tests are running d027f8f [Brennon York] proper piping of unnecessary stderr and stdout 6e2890d [Brennon York] updated test output newlines d9f6f7f [Brennon York] removed echo bad9a3a [Brennon 
York] added back the new deps test e9e3ad1 [Brennon York] removed escapes for quotes 97e5cfb [Brennon York] commenting out new deps script 17379a5 [Brennon York] Merge remote-tracking branch 'upstream/master' into SPARK-4123 56f74a8 [Brennon York] updated the unop for ensuring a test is available f2abc8c [Brennon York] removed the git checkout 6912584 [Brennon York] added this_mssg echo output c610d42 [Brennon York] removed the error to dev/null b98f78c [Brennon York] added the removed deps and echo output for jenkins testing 291a8fe [Brennon York] updated location of maven binary 126ce61 [Brennon York] removing new deps test to isolate why jenkins isn't posting messages f8011d8 [Brennon York] minor updates and style changes 63a35c9 [Brennon York] updated new dependencies test dae7ba8 [Brennon York] Capturing output directly from dependency builds 94d3547 [Brennon York] adding the new dependencies script into the test mix 2bca3c3 [Brennon York] added a git checkout 'git rev-parse HEAD' to the end of each pr test ae83b90 [Brennon York] removed jenkins tests to grab some values from the jenkins box 4110993 [Brennon York] beginning work on pr test to add new dependencies --- dev/run-tests-jenkins | 41 ++++++----- dev/tests/pr_new_dependencies.sh | 117 +++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 18 deletions(-) create mode 100755 dev/tests/pr_new_dependencies.sh diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3a937b637e003..f10aa6b59e1af 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -55,13 +55,14 @@ TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout # To write a PR test: #+ * the file must reside within the dev/tests directory #+ * be an executable bash script -#+ * accept two arguments on the command line, the first being the Github PR long commit -#+ hash and the second the Github SHA1 hash +#+ * accept three arguments on the command line, the first being the Github PR long commit +#+ hash, the second the Github SHA1 hash, and the final the current PR hash #+ * and, lastly, return string output to be included in the pr message output that will #+ be posted to Github PR_TESTS=( "pr_merge_ability" "pr_public_classes" + "pr_new_dependencies" ) function post_message () { @@ -146,34 +147,38 @@ function send_archived_logs () { fi } +# post start message +{ + start_message="\ + [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." 
+ + post_message "$start_message" +} + # Environment variable to capture PR test output pr_message="" +# Ensure we save off the current HEAD to revert to +current_pr_head="`git rev-parse HEAD`" # Run pull request tests for t in "${PR_TESTS[@]}"; do this_test="${FWDIR}/dev/tests/${t}.sh" - # Ensure the test is a file and is executable - if [ -x "$this_test" ]; then - echo "ghprb: $ghprbActualCommit sha1: $sha1" - this_mssg="`bash \"${this_test}\" \"${ghprbActualCommit}\" \"${sha1}\" 2>/dev/null`" + # Ensure the test can be found and is a file + if [ -f "${this_test}" ]; then + echo "Running test: $t" + this_mssg="$(bash "${this_test}" "${ghprbActualCommit}" "${sha1}" "${current_pr_head}")" # Check if this is the merge test as we submit that note *before* and *after* # the tests run [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}" pr_message="${pr_message}\n${this_mssg}" + # Ensure, after each test, that we're back on the current PR + git checkout -f "${current_pr_head}" &>/dev/null + else + echo "Cannot find test ${this_test}." fi done -# post start message -{ - start_message="\ - [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ - PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." - - start_message="${start_message}\n${merge_note}" - - post_message "$start_message" -} - # run tests { timeout "${TESTS_TIMEOUT}" ./dev/run-tests @@ -222,7 +227,7 @@ done PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" - result_message="${result_message}\n${pr_message}" + result_message="${result_message}${pr_message}" post_message "$result_message" } diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh new file mode 100755 index 0000000000000..115a5cd1354f0 --- /dev/null +++ b/dev/tests/pr_new_dependencies.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. 
+# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# Arg3: Current PR Commit Hash +#+ the PR hash for the current commit +# + +ghprbActualCommit="$1" +sha1="$2" +current_pr_head="$3" + +MVN_BIN="build/mvn" +CURR_CP_FILE="my-classpath.txt" +MASTER_CP_FILE="master-classpath.txt" + +# First switch over to the master branch +git checkout master &>/dev/null +# Find and copy all pom.xml files into a *.gate file that we can check +# against through various `git` changes +find -name "pom.xml" -exec cp {} {}.gate \; +# Switch back to the current PR +git checkout "${current_pr_head}" &>/dev/null + +# Check if any *.pom files from the current branch are different from the master +difference_q="" +for p in $(find -name "pom.xml"); do + [[ -f "${p}" && -f "${p}.gate" ]] && \ + difference_q="${difference_q}$(diff $p.gate $p)" +done + +# If no pom files were changed we can easily say no new dependencies were added +if [ -z "${difference_q}" ]; then + echo " * This patch does not change any dependencies." +else + # Else we need to manually build spark to determine what, if any, dependencies + # were added into the Spark assembly jar + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${CURR_CP_FILE} + + # Checkout the master branch to compare against + git checkout master &>/dev/null + + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${MASTER_CP_FILE} + + DIFF_RESULTS="`diff my-classpath.txt master-classpath.txt`" + + if [ -z "${DIFF_RESULTS}" ]; then + echo " * This patch does not change any dependencies." 
+ else + # Pretty print the new dependencies + added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{print " * \`"$1"\`"}') + removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{print " * \`"$1"\`"}') + added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" + removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" + + # Construct the final returned message with proper + return_mssg="" + [ -n "${added_deps}" ] && return_mssg="${added_deps_text}" + if [ -n "${removed_deps}" ]; then + if [ -n "${return_mssg}" ]; then + return_mssg="${return_mssg}\n${removed_deps_text}" + else + return_mssg="${removed_deps_text}" + fi + fi + echo "${return_mssg}" + fi + + # Remove the files we've left over + [ -f "${CURR_CP_FILE}" ] && rm -f "${CURR_CP_FILE}" + [ -f "${MASTER_CP_FILE}" ] && rm -f "${MASTER_CP_FILE}" + + # Clean up our mess from the Maven builds just in case + ${MVN_BIN} clean &>/dev/null +fi From e3eb393961051a48ed1cac756ac1928156aa161f Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Sun, 29 Mar 2015 12:40:37 +0100 Subject: [PATCH 070/171] [SPARK-6406] Launch Spark using assembly jar instead of a separate launcher jar Author: Nishkam Ravi Author: nishkamravi2 Author: nravi Closes #5085 from nishkamravi2/master_nravi and squashes the following commits: bad4349 [nishkamravi2] Update Main.java 36a6f87 [Nishkam Ravi] Minor changes and bug fixes b7f4ae7 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 4a45d6a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 458af39 [Nishkam Ravi] Locate the jar using getLocation, obviates the need to pass assembly path as an argument d9658d6 [Nishkam Ravi] Changes for SPARK-6406 ccdc334 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 3faa7a4 [Nishkam Ravi] Launcher library changes (SPARK-6406) 345206a [Nishkam Ravi] spark-class merge Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ac58975 [Nishkam Ravi] spark-class changes 06bfeb0 [nishkamravi2] Update spark-class 35af990 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 32c3ab3 [nishkamravi2] Update AbstractCommandBuilder.java 4bd4489 [nishkamravi2] Update AbstractCommandBuilder.java 746f35b [Nishkam Ravi] "hadoop" string in the assembly name should not be mandatory (everywhere else in spark we mandate spark-assembly*hadoop*.jar) bfe96e0 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi ee902fa [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi d453197 [nishkamravi2] Update NewHadoopRDD.scala 6f41a1d [nishkamravi2] Update NewHadoopRDD.scala 0ce2c32 [nishkamravi2] Update HadoopRDD.scala f7e33c2 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ba1eb8b [Nishkam Ravi] Try-catch block around the two occurrences of removeShutDownHook. Deletion of semi-redundant occurrences of expensive operation inShutDown. 
71d0e17 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 494d8c0 [nishkamravi2] Update DiskBlockManager.scala 3c5ddba [nishkamravi2] Update DiskBlockManager.scala f0d12de [Nishkam Ravi] Workaround for IllegalStateException caused by recent changes to BlockManager.stop 79ea8b4 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi b446edc [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 5c9a4cb [nishkamravi2] Update TaskSetManagerSuite.scala 535295a [nishkamravi2] Update TaskSetManager.scala 3e1b616 [Nishkam Ravi] Modify test for maxResultSize 9f6583e [Nishkam Ravi] Changes to maxResultSize code (improve error message and add condition to check if maxResultSize > 0) 5f8f9ed [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- bin/spark-class | 61 +++++++----- bin/spark-class2.cmd | 33 +++---- .../launcher/AbstractCommandBuilder.java | 99 +++++-------------- make-distribution.sh | 1 - 4 files changed, 69 insertions(+), 125 deletions(-) diff --git a/bin/spark-class b/bin/spark-class index e29b234afaf96..c03946d92e2e4 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -40,35 +40,46 @@ else fi fi -# Look for the launcher. In non-release mode, add the compiled classes directly to the classpath -# instead of looking for a jar file. -SPARK_LAUNCHER_CP= -if [ -f $SPARK_HOME/RELEASE ]; then - LAUNCHER_DIR="$SPARK_HOME/lib" - num_jars="$(ls -1 "$LAUNCHER_DIR" | grep "^spark-launcher.*\.jar$" | wc -l)" - if [ "$num_jars" -eq "0" -a -z "$SPARK_LAUNCHER_CP" ]; then - echo "Failed to find Spark launcher in $LAUNCHER_DIR." 1>&2 - echo "You need to build Spark before running this program." 
1>&2 - exit 1 - fi +# Find assembly jar +SPARK_ASSEMBLY_JAR= +if [ -f "$SPARK_HOME/RELEASE" ]; then + ASSEMBLY_DIR="$SPARK_HOME/lib" +else + ASSEMBLY_DIR="$SPARK_HOME/assembly/target/scala-$SPARK_SCALA_VERSION" +fi - LAUNCHER_JARS="$(ls -1 "$LAUNCHER_DIR" | grep "^spark-launcher.*\.jar$" || true)" - if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark launcher jars in $LAUNCHER_DIR:" 1>&2 - echo "$LAUNCHER_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 - fi +num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then + echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 + echo "You need to build Spark before running this program." 1>&2 + exit 1 +fi +ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" +if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 +fi - SPARK_LAUNCHER_CP="${LAUNCHER_DIR}/${LAUNCHER_JARS}" +SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" + +# Verify that versions of java used to build the jars and run Spark are compatible +if [ -n "$JAVA_HOME" ]; then + JAR_CMD="$JAVA_HOME/bin/jar" else - LAUNCHER_DIR="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION" - if [ ! -d "$LAUNCHER_DIR/classes" ]; then - echo "Failed to find Spark launcher classes in $LAUNCHER_DIR." 1>&2 - echo "You need to build Spark before running this program." 1>&2 + JAR_CMD="jar" +fi + +if [ $(command -v "$JAR_CMD") ] ; then + jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1) + if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then + echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 + echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 + echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 + echo "or build Spark with Java 6." 1>&2 exit 1 fi - SPARK_LAUNCHER_CP="$LAUNCHER_DIR/classes" fi # The launcher library will print arguments separated by a NULL character, to allow arguments with @@ -77,7 +88,7 @@ fi CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") -done < <("$RUNNER" -cp "$SPARK_LAUNCHER_CP" org.apache.spark.launcher.Main "$@") +done < <("$RUNNER" -cp "$SPARK_ASSEMBLY_JAR" org.apache.spark.launcher.Main "$@") if [ "${CMD[0]}" = "usage" ]; then "${CMD[@]}" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 37d22215a0e7e..4ce727bc99128 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -29,31 +29,20 @@ if "x%1"=="x" ( exit /b 1 ) -set LAUNCHER_CP=0 -if exist %SPARK_HOME%\RELEASE goto find_release_launcher +rem Find assembly jar +set SPARK_ASSEMBLY_JAR=0 -rem Look for the Spark launcher in both Scala build directories. The launcher doesn't use Scala so -rem it doesn't really matter which one is picked up. Add the compiled classes directly to the -rem classpath instead of looking for a jar file, since it's very common for people using sbt to use -rem the "assembly" target instead of "package". 
-set LAUNCHER_CLASSES=%SPARK_HOME%\launcher\target\scala-2.10\classes -if exist %LAUNCHER_CLASSES% ( - set LAUNCHER_CP=%LAUNCHER_CLASSES% +if exist "%SPARK_HOME%\RELEASE" ( + set ASSEMBLY_DIR=%SPARK_HOME%\lib +) else ( + set ASSEMBLY_DIR=%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION% ) -set LAUNCHER_CLASSES=%SPARK_HOME%\launcher\target\scala-2.11\classes -if exist %LAUNCHER_CLASSES% ( - set LAUNCHER_CP=%LAUNCHER_CLASSES% -) -goto check_launcher -:find_release_launcher -for %%d in (%SPARK_HOME%\lib\spark-launcher*.jar) do ( - set LAUNCHER_CP=%%d +for %%d in (%ASSEMBLY_DIR%\spark-assembly*hadoop*.jar) do ( + set SPARK_ASSEMBLY_JAR=%%d ) - -:check_launcher -if "%LAUNCHER_CP%"=="0" ( - echo Failed to find Spark launcher JAR. +if "%SPARK_ASSEMBLY_JAR%"=="0" ( + echo Failed to find Spark assembly JAR. echo You need to build Spark before running this program. exit /b 1 ) @@ -64,7 +53,7 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCHER_CP% org.apache.spark.launcher.Main %*"') do ( +for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %SPARK_ASSEMBLY_JAR% org.apache.spark.launcher.Main %*"') do ( set SPARK_CMD=%%i ) %SPARK_CMD% diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 2da5f7278729e..d8279145d8e90 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -86,10 +86,14 @@ public AbstractCommandBuilder() { */ List buildJavaCommand(String extraClassPath) throws IOException { List cmd = new ArrayList(); - if (javaHome == null) { - cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); - } else { + String envJavaHome; + + if (javaHome != null) { cmd.add(join(File.separator, javaHome, "bin", "java")); + } else if ((envJavaHome = System.getenv("JAVA_HOME")) != null) { + cmd.add(join(File.separator, envJavaHome, "bin", "java")); + } else { + cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); } // Load extra JAVA_OPTS from conf/java-opts, if it exists. @@ -182,59 +186,25 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - String assembly = findAssembly(); + final String assembly = AbstractCommandBuilder.class.getProtectionDomain().getCodeSource(). + getLocation().getPath(); addToClassPath(cp, assembly); - // When Hive support is needed, Datanucleus jars must be included on the classpath. Datanucleus - // jars do not work if only included in the uber jar as plugin.xml metadata is lost. Both sbt - // and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is built - // with Hive, so first check if the datanucleus jars exist, and then ensure the current Spark - // assembly is built for Hive, before actually populating the CLASSPATH with the jars. - // - // This block also serves as a check for SPARK-1703, when the assembly jar is built with - // Java 7 and ends up with too many files, causing issues with other JDK versions. 
- boolean needsDataNucleus = false; - JarFile assemblyJar = null; - try { - assemblyJar = new JarFile(assembly); - needsDataNucleus = assemblyJar.getEntry("org/apache/hadoop/hive/ql/exec/") != null; - } catch (IOException ioe) { - if (ioe.getMessage().indexOf("invalid CEN header") >= 0) { - System.err.println( - "Loading Spark jar failed.\n" + - "This is likely because Spark was compiled with Java 7 and run\n" + - "with Java 6 (see SPARK-1703). Please use Java 7 to run Spark\n" + - "or build Spark with Java 6."); - System.exit(1); - } else { - throw ioe; - } - } finally { - if (assemblyJar != null) { - try { - assemblyJar.close(); - } catch (IOException e) { - // Ignore. - } - } + // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only + // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate + // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + } else { + libdir = new File(sparkHome, "lib_managed/jars"); } - if (needsDataNucleus) { - System.err.println("Spark assembly has been built with Hive, including Datanucleus jars " + - "in classpath."); - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - } else { - libdir = new File(sparkHome, "lib_managed/jars"); - } - - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); - } + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); } } @@ -270,7 +240,6 @@ String getScalaVersion() { if (scala != null) { return scala; } - String sparkHome = getSparkHome(); File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); @@ -330,30 +299,6 @@ String getenv(String key) { return firstNonEmpty(childEnv.get(key), System.getenv(key)); } - private String findAssembly() { - String sparkHome = getSparkHome(); - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - } else { - libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); - } - - final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); - FileFilter filter = new FileFilter() { - @Override - public boolean accept(File file) { - return file.isFile() && re.matcher(file.getName()).matches(); - } - }; - File[] assemblies = libdir.listFiles(filter); - checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); - checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); - return assemblies[0].getAbsolutePath(); - } - private String getConfDir() { String confDir = getenv("SPARK_CONF_DIR"); return confDir != null ? 
confDir : join(File.separator, getSparkHome(), "conf"); diff --git a/make-distribution.sh b/make-distribution.sh index 9ed1abfe8c598..738a9c4d69601 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -199,7 +199,6 @@ echo "Build flags: $@" >> "$DISTDIR/RELEASE" # Copy jars cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$SPARK_HOME"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" -cp "$SPARK_HOME"/launcher/target/spark-launcher_$SCALA_VERSION-$VERSION.jar "$DISTDIR/lib/" # This will fail if the -Pyarn profile is not provided # In this case, silence the error and ignore the return code of this command cp "$SPARK_HOME"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : From 52ece26b8fb9769f6ed9167e3dffc8b1d7c61b02 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Sun, 29 Mar 2015 12:43:30 +0100 Subject: [PATCH 071/171] [SPARK-6558] Utils.getCurrentUserName returns the full principal name instead of login name Utils.getCurrentUserName returns UserGroupInformation.getCurrentUser().getUserName() when SPARK_USER isn't set. It should return UserGroupInformation.getCurrentUser().getShortUserName() getUserName() returns the users full principal name (ie user1CORP.COM). getShortUserName() returns just the users login name (user1). This just happens to work on YARN because the Client code sets: env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() Author: Thomas Graves Closes #5229 from tgravescs/SPARK-6558 and squashes the following commits: 24830bf [Thomas Graves] Utils.getCurrentUserName returns the full principal name instead of login name --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 0b5a914e7dbbf..bb8bd1015668a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2055,7 +2055,7 @@ private[spark] object Utils extends Logging { */ def getCurrentUserName(): String = { Option(System.getenv("SPARK_USER")) - .getOrElse(UserGroupInformation.getCurrentUser().getUserName()) + .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } } From 0e2753ff14e0d3f2433272c13ce26f67dc89767f Mon Sep 17 00:00:00 2001 From: "June.He" Date: Sun, 29 Mar 2015 12:47:22 +0100 Subject: [PATCH 072/171] [SPARK-6585][Tests]Fix FileServerSuite testcase in some Env. 
Change FileServerSuite.test("HttpFileServer should not work with SSL when the server is untrusted") catch SSLException Author: June.He Closes #5239 from sisihj/SPARK-6585 and squashes the following commits: cb19ae3 [June.He] Change FileServerSuite.test("HttpFileServer should not work with SSL when the server is untrusted") catch SSLException --- core/src/test/scala/org/apache/spark/FileServerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 5fdf6bc2777e3..a69e9b761f9a7 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io._ import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} -import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLException import com.google.common.io.ByteStreams import org.apache.commons.io.{FileUtils, IOUtils} @@ -228,7 +228,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { try { server.initialize() - intercept[SSLHandshakeException] { + intercept[SSLException] { fileTransferTest(server) } } finally { From a8d53afb4e119788fa0d9dd6b3e3ca94cea98581 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 29 Mar 2015 21:25:09 -0700 Subject: [PATCH 073/171] [SPARK-5124][Core] A standard RPC interface and an Akka implementation This PR added a standard internal RPC interface for Spark and an Akka implementation. See [the design document](https://issues.apache.org/jira/secure/attachment/12698710/Pluggable%20RPC%20-%20draft%202.pdf) for more details. I will split the whole work into multiple PRs to make it easier for code review. This is the first PR and avoid to touch too many files. 
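As a rough illustration of the endpoint model this patch adds, the following sketch defines and registers an endpoint against the new `RpcEnv` API. Only `RpcEnv.create`, `setupEndpoint`, `shutdown` and `awaitTermination` appear in the hunks below; the `EchoEndpoint` name, the string message, and the `send` call on the returned `RpcEndpointRef` are assumptions based on the design document, and the code must live under the `org.apache.spark` package because the interface is `private[spark]`.

```scala
package org.apache.spark.rpc

import org.apache.spark.{SecurityManager, SparkConf}

// A minimal, hypothetical endpoint for illustration: it just prints any String it receives.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def onStart(): Unit = println("EchoEndpoint started")

  override def receive: PartialFunction[Any, Unit] = {
    case msg: String => println(s"received: $msg")
  }
}

object RpcSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Create an RpcEnv (backed by the Akka implementation in this patch); port 0 picks a free port.
    val rpcEnv = RpcEnv.create("sketch", "localhost", 0, conf, new SecurityManager(conf))
    // Registering the endpoint returns an RpcEndpointRef that callers use to reach it.
    val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
    // Assumption: RpcEndpointRef exposes a fire-and-forget send(), per the design document;
    // its exact signature is not shown in the hunks quoted here.
    ref.send("hello")
    Thread.sleep(100) // give the one-way message a moment to be delivered before shutting down
    rpcEnv.shutdown()
    rpcEnv.awaitTermination()
  }
}
```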
Author: zsxwing Closes #4588 from zsxwing/rpc-part1 and squashes the following commits: fe3df4c [zsxwing] Move registerEndpoint and use actorSystem.dispatcher in asyncSetupEndpointRefByURI f6f3287 [zsxwing] Remove RpcEndpointRef.toURI 8bd1097 [zsxwing] Fix docs and the code style f459380 [zsxwing] Add RpcAddress.fromURI and rename urls to uris b221398 [zsxwing] Move send methods above ask methods 15cfd7b [zsxwing] Merge branch 'master' into rpc-part1 9ffa997 [zsxwing] Fix MiMa tests 78a1733 [zsxwing] Merge remote-tracking branch 'origin/master' into rpc-part1 385b9c3 [zsxwing] Fix the code style and add docs 2cc3f78 [zsxwing] Add an asynchronous version of setupEndpointRefByUrl e8dfec3 [zsxwing] Remove 'sendWithReply(message: Any, sender: RpcEndpointRef): Unit' 08564ae [zsxwing] Add RpcEnvFactory to create RpcEnv e5df4ca [zsxwing] Handle AkkaFailure(e) in Actor ec7c5b0 [zsxwing] Fix docs 7fc95e1 [zsxwing] Implement askWithReply in RpcEndpointRef 9288406 [zsxwing] Document thread-safety for setupThreadSafeEndpoint 3007c09 [zsxwing] Move setupDriverEndpointRef to RpcUtils and rename to makeDriverRef c425022 [zsxwing] Fix the code style 5f87700 [zsxwing] Move the logical of processing message to a private function 3e56123 [zsxwing] Use lazy to eliminate CountDownLatch 07f128f [zsxwing] Remove ActionScheduler.scala 4d34191 [zsxwing] Remove scheduler from RpcEnv 7cdd95e [zsxwing] Add docs for RpcEnv 51e6667 [zsxwing] Add 'sender' to RpcCallContext and rename the parameter of receiveAndReply to 'context' ffc1280 [zsxwing] Rename 'fail' to 'sendFailure' and other minor code style changes 28e6d0f [zsxwing] Add onXXX for network events and remove the companion objects of network events 3751c97 [zsxwing] Rename RpcResponse to RpcCallContext fe7d1ff [zsxwing] Add explicit reply in rpc 7b9e0c9 [zsxwing] Fix the indentation 04a106e [zsxwing] Remove NopCancellable and add a const NOP in object SettableCancellable 2a579f4 [zsxwing] Remove RpcEnv.systemName 155b987 [zsxwing] Change newURI to uriOf and add some comments 45b2317 [zsxwing] A standard RPC interface and An Akka implementation --- .../scala/org/apache/spark/SparkEnv.scala | 42 +- .../spark/deploy/worker/DriverWrapper.scala | 11 +- .../spark/deploy/worker/WorkerWatcher.scala | 59 +- .../CoarseGrainedExecutorBackend.scala | 2 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 429 ++++++++++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 318 +++++++++++ .../scheduler/OutputCommitCoordinator.scala | 37 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../org/apache/spark/util/RpcUtils.scala | 35 ++ .../deploy/worker/WorkerWatcherSuite.scala | 38 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 525 ++++++++++++++++++ .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 50 ++ project/MimaExcludes.scala | 4 +- 13 files changed, 1466 insertions(+), 86 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala create mode 100644 core/src/main/scala/org/apache/spark/util/RpcUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2a0c7e756dd3a..4a2ed82a40dec 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,12 +34,14 @@ import 
org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} -import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** * :: DeveloperApi :: @@ -54,7 +56,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, + private[spark] val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -71,6 +73,9 @@ class SparkEnv ( val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { + // TODO Remove actorSystem + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -91,7 +96,8 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - actorSystem.shutdown() + rpcEnv.shutdown() + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. @@ -236,16 +242,15 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem // Figure out which port Akka actually bound to in case the original port is 0 or occupied. 
if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.address.port.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -290,6 +295,15 @@ object SparkEnv extends Logging { } } + def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + if (isDriver) { + logInfo("Registering " + name) + rpcEnv.setupEndpoint(name, endpointCreator) + } else { + RpcUtils.makeDriverRef(name, conf, rpcEnv) + } + } + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { @@ -377,13 +391,13 @@ object SparkEnv extends Logging { val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf) } - val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator", - new OutputCommitCoordinatorActor(outputCommitCoordinator)) - outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor) + val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) + outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) val envInstance = new SparkEnv( executorId, - actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index deef6ef9043c6..d1a12b01e78f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,10 +19,9 @@ package org.apache.spark.deploy.worker import java.io.File -import akka.actor._ - import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. 
@@ -39,9 +38,9 @@ object DriverWrapper { */ case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader val userJarUrl = new File(userJar).toURI().toURL() @@ -58,7 +57,7 @@ object DriverWrapper { val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) - actorSystem.shutdown() + rpcEnv.shutdown() case _ => System.err.println("Usage: DriverWrapper [options]") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index e0790274d7d3e..83fb991891a41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,58 +17,63 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { - - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends RpcEndpoint with Logging { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + } } // Used to avoid shutting down JVM during tests + // In the normal case, exitNonZero will call `System.exit(-1)` to shutdown the JVM. In the unit + // test, the user should call `setTesting(true)` so that `exitNonZero` will set `isShutDown` to + // true rather than calling `System.exit`. The user can check `isShutDown` to know if + // `exitNonZero` is called. 
private[deploy] var isShutDown = false private[deploy] def setTesting(testing: Boolean) = isTesting = testing private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + private val expectedAddress = RpcAddress.fromURIString(workerUrl) + private def isWorker(address: RpcAddress) = expectedAddress == address private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - logInfo(s"Successfully connected to $workerUrl") + override def receive: PartialFunction[Any, Unit] = { + case e => logWarning(s"Received unexpected message: $e") + } - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. Exiting.") exitNonZero() + } + } - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - - case e => logWarning(s"Received unexpected actor system event: $e") + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b5205d4e997ae..900e678ee02ef 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -169,7 +169,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverUrl, executorId, sparkHostPort, cores, userClassPath, env), name = "Executor") workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 0000000000000..7985941d949c0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.net.URI + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} +import org.apache.spark.util.{AkkaUtils, Utils} + +/** + * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to + * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote + * nodes, and deliver them to corresponding [[RpcEndpoint]]s. + * + * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. + */ +private[spark] abstract class RpcEnv(conf: SparkConf) { + + private[spark] val defaultLookupTimeout = AkkaUtils.lookupTimeout(conf) + + /** + * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement + * [[RpcEndpoint.self]]. + * + * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this + * on a non-existent endpoint. + */ + private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Return the address that [[RpcEnv]] is listening to. + */ + def address: RpcAddress + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not + * guarantee thread-safety. + */ + def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should + * make sure thread-safely sending messages to [[RpcEndpoint]]. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[RpcEndpoint]]. In the other words, changes to internal fields of a [[RpcEndpoint]] + * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be + * volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]] + * for different messages. + */ + def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. + */ + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. + */ + def setupEndpointRefByURI(uri: String): RpcEndpointRef = { + Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` + * asynchronously. 
+ */ + def asyncSetupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { + asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * This is a blocking action. + */ + def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Stop [[RpcEndpoint]] specified by `endpoint`. + */ + def stop(endpoint: RpcEndpointRef): Unit + + /** + * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[shutdown()]]. + */ + def shutdown(): Unit + + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? + */ + def awaitTermination(): Unit + + /** + * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of + * creating it manually because different [[RpcEnv]] may have different formats. + */ + def uriOf(systemName: String, address: RpcAddress, endpointName: String): String +} + +private[spark] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). + newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart receive* onStop + * + * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * [[RpcEnv.setupThreadSafeEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is + * called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. 
So don't call `self` before `onStart` is called. + */ + final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + */ + def receive: PartialFunction[Any, Unit] = { + case _ => throw new SparkException(self + " does not implement 'receive'") + } + + /** + * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. + */ + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => context.sendFailure(new SparkException(self + " won't reply anything")) + } + + /** + * Call onError when any exception is thrown during handling messages. + * + * @param cause + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(self) + } + } +} + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) + private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) + private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. 
+ * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = sendWithReply[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } + +} + +/** + * Represent a host with a port + */ +private[spark] case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort +} + +private[spark] object RpcAddress { + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURI(uri: URI): RpcAddress = { + RpcAddress(uri.getHost, uri.getPort) + } + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURIString(uri: String): RpcAddress = { + fromURI(new java.net.URI(uri)) + } + + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. 
+ */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 0000000000000..769d59b7b3343 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import java.net.URI +import java.util.concurrent.ConcurrentHashMap + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.pattern.{ask => akkaAsk} +import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ActorLogReceive, AkkaUtils} + +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private[akka] ( + val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + extends RpcEnv(conf) with Logging { + + private val defaultAddress: RpcAddress = { + val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + // In some test case, ActorSystem doesn't bind to any address. + // So just use some default value since they are only some unit tests + RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) + } + + override val address: RpcAddress = defaultAddress + + /** + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. + */ + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + /** + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` + */ + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) + } + + private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } + } + + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. 
+ */ + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { + val endpointRef = endpointToRef.get(endpoint) + require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") + endpointRef + } + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + setupThreadSafeEndpoint(name, endpoint) + } + + override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + @volatile var endpointRef: AkkaRpcEndpointRef = null + // Use lazy because the Actor needs to use `endpointRef`. + // So `actorRef` should be created after assigning `endpointRef`. + lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + assert(endpointRef != null) + + override def preStart(): Unit = { + // Listen for remote client network events + context.system.eventStream.subscribe(self, classOf[AssociationEvent]) + safelyCall(endpoint) { + endpoint.onStart() + } + } + + override def receiveWithLogging: Receive = { + case AssociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case DisassociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + safelyCall(endpoint) { + endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } + + case e: AssociationEvent => + // TODO ignore? + + case m: AkkaMessage => + logDebug(s"Received RPC message: $m") + safelyCall(endpoint) { + processMessage(endpoint, m, sender) + } + + case AkkaFailure(e) => + safelyCall(endpoint) { + throw e + } + + case message: Any => { + logWarning(s"Unknown message: $message") + } + + } + + override def postStop(): Unit = { + unregisterEndpoint(endpoint.self) + safelyCall(endpoint) { + endpoint.onStop() + } + } + + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) + registerEndpoint(endpoint, endpointRef) + // Now actorRef can be created safely + endpointRef.init() + endpointRef + } + + private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { + val message = m.message + val needReply = m.needReply + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(new RpcCallContext { + override def sendFailure(e: Throwable): Unit = { + _sender ! AkkaFailure(e) + } + + override def reply(response: Any): Unit = { + _sender ! AkkaMessage(response, false) + } + + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + }) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](message, { message => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + _sender ! AkkaFailure(e) + } else { + throw e + } + } + } + + /** + * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will + * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. 
+ */ + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } + + private def akkaAddressToRpcAddress(address: Address): RpcAddress = { + RpcAddress(address.host.getOrElse(defaultAddress.host), + address.port.getOrElse(defaultAddress.port)) + } + + override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + import actorSystem.dispatcher + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { + AkkaUtils.address( + AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) + } + + override def shutdown(): Unit = { + actorSystem.shutdown() + } + + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + + override def toString: String = s"${getClass.getSimpleName}($actorSystem)" +} + +private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + new AkkaRpcEnv(actorSystem, config.conf, boundPort) + } +} + +private[akka] class AkkaRpcEndpointRef( + @transient defaultAddress: RpcAddress, + @transient _actorRef: => ActorRef, + @transient conf: SparkConf, + @transient initInConstructor: Boolean = true) + extends RpcEndpointRef(conf) with Logging { + + lazy val actorRef = _actorRef + + override lazy val address: RpcAddress = { + val akkaAddress = actorRef.path.address + RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), + akkaAddress.port.getOrElse(defaultAddress.port)) + } + + override lazy val name: String = actorRef.path.name + + private[akka] def init(): Unit = { + // Initialize the lazy vals + actorRef + address + name + } + + if (initInConstructor) { + init() + } + + override def send(message: Any): Unit = { + actorRef ! AkkaMessage(message, false) + } + + override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + import scala.concurrent.ExecutionContext.Implicits.global + actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + case msg @ AkkaMessage(message, reply) => + if (reply) { + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }.mapTo[T] + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" + +} + +/** + * A wrapper to `message` so that the receiver knows if the sender expects a reply. 
+ * @param message + * @param needReply if the sender expects a reply message + */ +private[akka] case class AkkaMessage(message: Any, needReply: Boolean) + +/** + * A reply with the failure error from the receiver to the sender + */ +private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index a3caa9f000c89..f748f394d1347 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable -import akka.actor.{ActorRef, Actor} - import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, ActorLogReceive} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} private sealed trait OutputCommitCoordinationMessage extends Serializable @@ -34,8 +32,8 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem * policy. * * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is - * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit - * output will be forwarded to the driver's OutputCommitCoordinator. + * configured with a reference to the driver's OutputCommitCoordinatorEndpoint, so requests to + * commit output will be forwarded to the driver's OutputCommitCoordinator. * * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. @@ -43,10 +41,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { // Initialized by SparkEnv - var coordinatorActor: Option[ActorRef] = None - private val timeout = AkkaUtils.askTimeout(conf) - private val maxAttempts = AkkaUtils.numRetries(conf) - private val retryInterval = AkkaUtils.retryWaitMs(conf) + var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int private type PartitionId = Long @@ -81,9 +76,9 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { partition: PartitionId, attempt: TaskAttemptId): Boolean = { val msg = AskPermissionToCommitOutput(stage, partition, attempt) - coordinatorActor match { - case Some(actor) => - AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout) + coordinatorRef match { + case Some(endpointRef) => + endpointRef.askWithReply[Boolean](msg) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") @@ -125,8 +120,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { } def stop(): Unit = synchronized { - coordinatorActor.foreach(_ ! 
StopCoordinator) - coordinatorActor = None + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None authorizedCommittersByStage.clear() } @@ -157,16 +152,18 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private[spark] object OutputCommitCoordinator { // This actor is used only for RPC - class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) - extends Actor with ActorLogReceive with Logging { + private[spark] class OutputCommitCoordinatorEndpoint( + override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) + extends RpcEndpoint with Logging { - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => - sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) + context.reply( + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) case StopCoordinator => logInfo("OutputCommitCoordinator stopped!") - context.stop(self) - sender ! true + context.reply(true) + stop() } } } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 48a6ede05e17b..6c2c5261306e7 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -179,7 +179,7 @@ private[spark] object AkkaUtils extends Logging { message: Any, actor: ActorRef, maxAttempts: Int, - retryInterval: Int, + retryInterval: Long, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala new file mode 100644 index 0000000000000..6665b17c3d5df --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} + +object RpcUtils { + + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. 
+ */ + def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { + val driverActorSystemName = SparkEnv.driverActorSystemName + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 5e538d6fab2a1..6a6f29dd613cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,32 +17,38 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString, Props} -import akka.testkit.TestActorRef -import akka.remote.DisassociatedEvent +import akka.actor.AddressFromURIString +import org.apache.spark.SparkConf +import org.apache.spark.SecurityManager +import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) - assert(actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected( + RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + assert(workerWatcher.isShutDown) + rpcEnv.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" - val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" + val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) - assert(!actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + assert(!workerWatcher.isShutDown) + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala new file mode 100644 index 
0000000000000..e07bdb9637575 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -0,0 +1,525 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} + +import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{SparkException, SparkConf} + +/** + * Common tests for an RpcEnv implementation. + */ +abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { + + var env: RpcEnv = _ + + override def beforeAll(): Unit = { + val conf = new SparkConf() + env = createRpcEnv(conf, "local", 12345) + } + + override def afterAll(): Unit = { + if(env != null) { + env.shutdown() + } + } + + def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv + + test("send a message locally") { + @volatile var message: String = null + val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case msg: String => message = msg + } + }) + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } + + test("send a message remotely") { + @volatile var message: String = null + // Set up a RpcEndpoint using env + env.setupEndpoint("send-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case msg: String => message = msg + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + try { + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("send a RpcEndpointRef") { + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case "Hello" => context.reply(self) + case "Echo" => context.reply("Echo") + } + } + val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) + + val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithReply[String]("Echo") + assert("Echo" === reply) + } + + test("ask a message locally") { + val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => { + context.reply(msg) + } 
+ } + }) + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } + + test("ask a message remotely") { + env.setupEndpoint("ask-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => { + context.reply(msg) + } + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + try { + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("ask a message timeout") { + env.setupEndpoint("ask-timeout", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => { + Thread.sleep(100) + context.reply(msg) + } + } + }) + + val conf = new SparkConf() + conf.set("spark.akka.retry.wait", "0") + conf.set("spark.akka.num.retries", "1") + val anotherEnv = createRpcEnv(conf, "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") + try { + val e = intercept[Exception] { + rpcEndpointRef.askWithReply[String]("hello", 1 millis) + } + assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("onStart and onStop") { + val stopLatch = new CountDownLatch(1) + val calledMethods = mutable.ArrayBuffer[String]() + + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + calledMethods += "start" + } + + override def receive = { + case msg: String => + } + + override def onStop(): Unit = { + calledMethods += "stop" + stopLatch.countDown() + } + } + val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint) + env.stop(rpcEndpointRef) + stopLatch.await(10, TimeUnit.SECONDS) + assert(List("start", "stop") === calledMethods) + } + + test("onError: error in onStart") { + @volatile var e: Throwable = null + env.setupEndpoint("onError-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + throw new RuntimeException("Oops!") + } + + override def receive = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in onStop") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + + override def onStop(): Unit = { + throw new RuntimeException("Oops!") + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in receive") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => throw new RuntimeException("Oops!") + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage 
=== "Oops!") + } + } + + test("self: call in onStart") { + @volatile var callSelfSuccessfully = false + + env.setupEndpoint("self-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + self + callSelfSuccessfully = true + } + + override def receive = { + case m => + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStart` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in receive") { + @volatile var callSelfSuccessfully = false + + val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => { + self + callSelfSuccessfully = true + } + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `receive` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in onStop") { + @volatile var e: Throwable = null + + val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => + } + + override def onStop(): Unit = { + self + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStop` is invalid + assert(e != null) + assert(e.getMessage.contains("Cannot find RpcEndpointRef")) + } + } + + test("call receive in sequence") { + // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it + for(i <- 0 until 100) { + @volatile var result = 0 + val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => result += 1 + } + + }) + + (0 until 10) foreach { _ => + new Thread { + override def run() { + (0 until 100) foreach { _ => + endpointRef.send("Hello") + } + } + }.start() + } + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(result == 1000) + } + + env.stop(endpointRef) + } + } + + test("stop(RpcEndpointRef) reentrant") { + @volatile var onStopCount = 0 + val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => + } + + override def onStop(): Unit = { + onStopCount += 1 + } + }) + + env.stop(endpointRef) + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(5 millis)) { + // Calling stop twice should only trigger onStop once. 
+ assert(onStopCount == 1) + } + } + + test("sendWithReply") { + val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case m => context.reply("ack") + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely") { + env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case m => context.reply("ack") + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("sendWithReply: error") { + val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case m => context.sendFailure(new SparkException("Oops")) + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely error") { + env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => context.sendFailure(new SparkException("Oops")) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-remotely-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("network events") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupThreadSafeEndpoint("network-events", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List(("onConnected", remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List( + ("onConnected", remoteAddress), + 
("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress))) + } + } +} + +case object Start + +case class Ping(id: Int) + +case class Pong(id: Int) diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala new file mode 100644 index 0000000000000..58214c0637235 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import org.apache.spark.rpc._ +import org.apache.spark.{SecurityManager, SparkConf} + +class AkkaRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + } + + test("setupEndpointRef: systemName, address, endpointName") { + val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case _ => + } + }) + val conf = new SparkConf() + val newRpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + try { + val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") + assert("akka.tcp://local@localhost:12345/user/test_endpoint" === + newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) + } finally { + newRpcEnv.shutdown() + } + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b9f40046e15a2..efd59a7e5470f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,7 +50,9 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor") ) ++ Seq( // SPARK-6510 Add a Graph#minus method acting as Set#difference ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") From 01dc9f50d1aae1f24021062291d73182a2622f2c Mon Sep 17 00:00:00 2001 From: Li Zhihui Date: Sun, 29 Mar 2015 21:30:37 -0700 Subject: [PATCH 074/171] Fix string interpolator error in HeartbeatReceiver Error log before fixed 15/03/29 10:07:25 ERROR YarnScheduler: Lost an executor 24 (already removed): Executor heartbeat timed out after ${now - lastSeenMs} ms Author: Li Zhihui Closes #5255 from li-zhihui/fixstringinterpolator and squashes 
the following commits: c93f2b7 [Li Zhihui] Fix string interpolator error in HeartbeatReceiver --- core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 548dcb93c3358..8435e1ea2611c 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -89,7 +89,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule logWarning(s"Removing executor $executorId with no recent heartbeats: " + s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms") scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + - "timed out after ${now - lastSeenMs} ms")) + s"timed out after ${now - lastSeenMs} ms")) if (sc.supportDynamicAllocation) { sc.killExecutor(executorId) } From 17b13c53ec9d8579a7fb801ab781bce43809db6a Mon Sep 17 00:00:00 2001 From: Eran Medan Date: Mon, 30 Mar 2015 00:02:52 -0700 Subject: [PATCH 075/171] [spark-sql] a better exception message than "scala.MatchError" for unsupported types in Schema creation Currently if trying to register an RDD (or DataFrame in 1.3) as a table that has types that have no supported Schema representation (e.g. type "Any") - it would throw a match error. e.g. scala.MatchError: Any (of class scala.reflect.internal.Types$ClassNoArgsTypeRef) This fix is just to have a nicer error message than a MatchError Author: Eran Medan Closes #5235 from eranation/patch-2 and squashes the following commits: af4b1a2 [Eran Medan] Line should be under 100 chars 0c69e9d [Eran Medan] Change from sys.error UnsupportedOperationException 524be86 [Eran Medan] better exception than scala.MatchError: Any --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d6126c24fc50d..2220970085462 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -179,6 +179,8 @@ trait ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case other => + throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } From de6733036e060e18b0d1f21f9365bda81132a1a2 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 30 Mar 2015 11:41:43 +0100 Subject: [PATCH 076/171] [SPARK-6596] fix the instruction on building scaladoc In README.md under docs/ directory, it says that > You can build just the Spark scaladoc by running build/sbt doc from the SPARK_PROJECT_ROOT directory. 
I guess the right approach is build/sbt unidoc Author: CodingCat Closes #5253 from CodingCat/SPARK-6596 and squashes the following commits: af379ed [CodingCat] fix the instruction on building scaladoc --- docs/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 8a54724c4beae..3773ea25c8b67 100644 --- a/docs/README.md +++ b/docs/README.md @@ -60,7 +60,7 @@ We use Sphinx to generate Python API docs, so you will need to install it by run ## API Docs (Scaladoc and Sphinx) -You can build just the Spark scaladoc by running `build/sbt doc` from the SPARK_PROJECT_ROOT directory. +You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as @@ -68,7 +68,7 @@ public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a -jekyll plugin to run `build/sbt doc` before building the site so if you haven't run it (recently) it +jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). From 4bdfb7bab3b9d20167571d9b6888a2a44d9d43fc Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Mon, 30 Mar 2015 11:52:02 +0100 Subject: [PATCH 077/171] [SPARK-5750][SPARK-3441][SPARK-5836][CORE] Added documentation explaining shuffle I've updated the Spark Programming Guide to add a section on the shuffle operation providing some background on what it does. I've also addressed some of its performance impacts. I've included documentation to address the following issues: https://issues.apache.org/jira/browse/SPARK-5836 https://issues.apache.org/jira/browse/SPARK-3441 https://issues.apache.org/jira/browse/SPARK-5750 https://issues.apache.org/jira/browse/SPARK-4227 is related but can be addressed in a separate PR since it involves updates to the Spark Configuration Guide. Author: Ilya Ganelin Author: Ilya Ganelin Closes #5074 from ilganeli/SPARK-5750 and squashes the following commits: 6178e24 [Ilya Ganelin] Update programming-guide.md 7a0b96f [Ilya Ganelin] Update programming-guide.md 2c5df08 [Ilya Ganelin] Merge branch 'SPARK-5750' of github.com:ilganeli/spark into SPARK-5750 dffbd2d [Ilya Ganelin] [SPARK-5750] Slight wording update 1ff4eb4 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5750 85f9c6e [Ilya Ganelin] Update programming-guide.md 349d1fa [Ilya Ganelin] Added cross linkf or configuration page eeb5a7a [Ilya Ganelin] [SPARK-5750] Added some minor fixes dd5cc9d [Ilya Ganelin] [SPARK-5750] Fixed some factual inaccuracies with regards to shuffle internals. a8adb57 [Ilya Ganelin] [SPARK-5750] Incoporated feedback from Sean Owen 9954bbe [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5750 159dd1c [Ilya Ganelin] [SPARK-5750] Style fixes from rxin. 
75ef67b [Ilya Ganelin] [SPARK-5750][SPARK-3441][SPARK-5836] Added documentation explaining the shuffle operation and included errata from a number of other JIRAs --- docs/programming-guide.md | 83 +++++++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f5b775da7930a..f4fabb0927b66 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -937,7 +937,7 @@ for details. Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item). - mapPartitions(func) + mapPartitions(func) Similar to map, but runs separately on each partition (block) of the RDD, so func must be of type Iterator<T> => Iterator<U> when running on an RDD of type T. @@ -964,7 +964,7 @@ for details. Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will yield much better @@ -975,25 +975,25 @@ for details. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. @@ -1006,17 +1006,17 @@ for details. process's stdin and lines output to its stdout are returned as an RDD of strings. - coalesce(numPartitions) + coalesce(numPartitions) Decrease the number of partitions in the RDD to numPartitions. Useful for running operations more efficiently after filtering down a large dataset. repartition(numPartitions) Reshuffle the data in the RDD randomly to create either more or fewer partitions and balance it across them. - This always shuffles all data over the network. + This always shuffles all data over the network. - repartitionAndSortWithinPartitions(partitioner) + repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1080,7 +1080,7 @@ for details. SparkContext.objectFile(). - countByKey() + countByKey() Only available on RDDs of type (K, V). Returns a hashmap of (K, Int) pairs with the count of each key. @@ -1090,6 +1090,67 @@ for details. +### Shuffle operations + +Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's +mechanism for re-distributing data so that is grouped differently across partitions. This typically +involves copying data across executors and machines, making the shuffle a complex and +costly operation. + +#### Background + +To understand what happens during the shuffle we can consider the example of the +[`reduceByKey`](#ReduceByLink) operation. 
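As an illustrative sketch (not part of the original guide text; the sample data and variable names are invented, and an existing `SparkContext` named `sc` is assumed, as elsewhere in the guide), such a per-key aggregation might be written as:

```scala
// Count occurrences of each word. The map step runs locally within each
// partition; the reduceByKey step forces a shuffle so that every count for a
// given word ends up on the same reduce task.
val words = sc.parallelize(Seq("spark", "shuffle", "spark", "rdd", "shuffle", "spark"))
val counts = words
  .map(word => (word, 1))
  .reduceByKey(_ + _)
counts.collect().foreach(println)   // e.g. (spark,3), (shuffle,2), (rdd,1)
```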
The `reduceByKey` operation generates a new RDD where all +values for a single key are combined into a tuple - the key and the result of executing a reduce +function against all values associated with that key. The challenge is that not all values for a +single key necessarily reside on the same partition, or even the same machine, but they must be +co-located to compute the result. + +In Spark, data is generally not distributed across partitions to be in the necessary place for a +specific operation. During computations, a single task will operate on a single partition - thus, to +organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - +this is called the **shuffle**. + +Although the set of elements in each partition of newly shuffled data will be deterministic, and so +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: + +* `mapPartitions` to sort each partition using, for example, `.sorted` +* `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning +* `sortBy` to make a globally ordered RDD + +Operations which can cause a shuffle include **repartition** operations like +[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +(except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and +**join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). + +#### Performance Impact +The **Shuffle** is an expensive operation since it involves disk I/O, data serialization, and +network I/O. To organize data for the shuffle, Spark generates sets of tasks - *map* tasks to +organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from +MapReduce and does not directly relate to Spark's `map` and `reduce` operations. + +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks +read the relevant sorted blocks. + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables +to disk, incurring the additional overhead of disk I/O and increased garbage collection. + +Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files +are not cleaned up from Spark's temporary storage until Spark is stopped, which means that +long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need +to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +`spark.local.dir` configuration parameter when configuring the Spark context. + +Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). 
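As a small sketch of the ordering options listed above (the sample data and partition count are illustrative only, and an existing `SparkContext` named `sc` is assumed):

```scala
import org.apache.spark.HashPartitioner

val pairs = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3), ("a", 4)))

// Push the per-partition sort into the shuffle machinery itself.
val sortedWithinPartitions =
  pairs.repartitionAndSortWithinPartitions(new HashPartitioner(4))

// Or produce a globally ordered RDD.
val globallySorted = pairs.sortBy(_._1)
```

Either form gives deterministic element ordering after the shuffle, at the cost of the sort; which one is appropriate depends on whether a global order is actually needed.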
+ ## RDD Persistence One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory From 19d4c392fa1738e5dd04418cb008abc8810b8122 Mon Sep 17 00:00:00 2001 From: Jose Manuel Gomez Date: Mon, 30 Mar 2015 14:59:08 +0100 Subject: [PATCH 078/171] [HOTFIX] Update start-slave.sh wihtout this change the below error happens when I execute sbin/start-all.sh localhost: /spark-1.3/sbin/start-slave.sh: line 32: unexpected EOF while looking for matching `"' localhost: /spark-1.3/sbin/start-slave.sh: line 33: syntax error: unexpected end of file my operating system is Linux Mint 17.1 Rebecca Author: Jose Manuel Gomez Closes #5262 from josegom/patch-2 and squashes the following commits: 453af8b [Jose Manuel Gomez] Update start-slave.sh 2c456bd [Jose Manuel Gomez] Update start-slave.sh --- sbin/start-slave.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index c0155384f7395..5a6de11afdd3d 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -19,7 +19,7 @@ # Starts a slave on the machine this script is executed on. -usage="Usage: start-slave.sh where is like "spark://localhost:7077" +usage="Usage: start-slave.sh where is like spark://localhost:7077" if [ $# -lt 2 ]; then echo $usage From fe81f6c779213a91369ec61cf5489ad5c66cc49c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 30 Mar 2015 22:24:12 +0800 Subject: [PATCH 079/171] [SPARK-6595][SQL] MetastoreRelation should be a MultiInstanceRelation Now that we have `DataFrame`s it is possible to have multiple copies in a single query plan. As such, it needs to inherit from `MultiInstanceRelation` or self joins will break. I also add better debugging errors when our self join handling fails in case there are future bugs. Author: Michael Armbrust Closes #5251 from marmbrus/multiMetaStore and squashes the following commits: 4272f6d [Michael Armbrust] [SPARK-6595][SQL] MetastoreRelation should be MuliInstanceRelation --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 +++++++++- .../sql/catalyst/analysis/MultiInstanceRelation.scala | 2 +- .../apache/spark/sql/hive/HiveMetastoreCatalog.scala | 11 +++++++++-- .../spark/sql/hive/HiveMetastoreCatalogSuite.scala | 8 ++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 44eceb0b372e6..ba1ac141b9fab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -252,7 +252,15 @@ class Analyzer(catalog: Catalog, case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - }.head // Only handle first case found, others will be fixed on the next pass. + }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. 
+ sys.error( + s""" + |Failure when resolving conflicting references in Join: + |$plan + | + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + """.stripMargin) + } val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 894c3500cf533..35b74024a4cab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -30,5 +30,5 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * of itself with globally unique expression ids. */ trait MultiInstanceRelation { - def newInstance(): this.type + def newInstance(): LogicalPlan } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d1a99555e90c6..203164ea84292 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging import org.apache.spark.sql.{SaveMode, AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Catalog, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NoSuchTableException, Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical @@ -697,7 +697,7 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) (@transient sqlContext: SQLContext) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { self: Product => @@ -778,6 +778,13 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + + override def newInstance() = { + val newCopy = MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) + // The project here is an ugly hack to work around the fact that MetastoreRelation's + // equals method is broken. Please remove this when SPARK-6555 is fixed. 
+ Project(newCopy.output, newCopy) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index aad48ada52642..fa8e11ffec2b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.hive.test.TestHive import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT @@ -36,4 +37,11 @@ class HiveMetastoreCatalogSuite extends FunSuite { assert(HiveMetastoreTypes.toMetastoreType(udt) === HiveMetastoreTypes.toMetastoreType(udt.sqlType)) } + + test("duplicated metastore relations") { + import TestHive.implicits._ + val df = TestHive.sql("SELECT * FROM src") + println(df.queryExecution) + df.as('a).join(df.as('b), $"a.key" === $"b.key") + } } From 32259c671ab419f4c8a6ba8e2f7d676c5dfd0f4f Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 30 Mar 2015 11:54:44 -0700 Subject: [PATCH 080/171] [SPARK-6592][SQL] fix filter for scaladoc to generate API doc for Row class under catalyst dir https://issues.apache.org/jira/browse/SPARK-6592 The current impl in SparkBuild.scala filter all classes under catalyst directory, however, we have a corner case that Row class is a public API under that directory we need to include Row into the scaladoc while still excluding other classes of catalyst project Thanks for the help on this patch from rxin and liancheng Author: CodingCat Closes #5252 from CodingCat/SPARK-6592 and squashes the following commits: 02098a4 [CodingCat] ignore collection, enable types (except those protected classes) f7af2cb [CodingCat] commit 3ab4403 [CodingCat] fix filter for scaladoc to generate API doc for Row.scala under catalyst directory --- project/SparkBuild.scala | 16 ++++++++-------- .../spark/sql/types/DataTypeConversions.scala | 2 +- .../apache/spark/sql/types/DataTypeParser.scala | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ac37c605de4b6..d3faa551a4b14 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -360,15 +360,15 @@ object Unidoc { packages .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("akka"))) - .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) - .map(_.filterNot(_.getCanonicalPath.contains("network"))) - .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) - .map(_.filterNot(_.getCanonicalPath.contains("executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) - .map(_.filterNot(_.getCanonicalPath.contains("collection"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/catalyst"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/execution"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/hive/test"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) + 
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } lazy val settings = scalaJavaUnidocSettings ++ Seq ( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala index c243be07a91b6..a9d63e784963d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -protected[sql] object DataTypeConversions { +private[sql] object DataTypeConversions { def productToRow(product: Product, schema: StructType): Row = { val mutableRow = new GenericMutableRow(product.productArity) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 89278f7dbc806..34270d0ca7cd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -112,4 +112,4 @@ private[sql] object DataTypeParser { } /** The exception thrown from the [[DataTypeParser]]. */ -protected[sql] class DataTypeException(message: String) extends Exception(message) +private[sql] class DataTypeException(message: String) extends Exception(message) From df3550084c9975f999ed370dd9f7c495181a68ba Mon Sep 17 00:00:00 2001 From: Brennon York Date: Mon, 30 Mar 2015 12:48:26 -0700 Subject: [PATCH 081/171] [HOTFIX][SPARK-4123]: Updated to fix bug where multiple dependencies added breaks Github output Currently there is a bug whereby if a new patch introduces more than one new dependency (or removes more than one) it breaks the Github post output (see [this build](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/29399/consoleFull)). This hotfix will remove `awk` print statements in place of `printf` so as not to automatically add the newline character which is then escaped and added directly at the end of the `awk` statement. 
This should take a failed build output such as: ```json data: {"body": " [Test build #29400 has finished](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/29400/consoleFull) for PR 5266 at commit [`2aa4be0`](https://github.com/apache/spark/commit/2aa4be0e1d7ce052f8c901c6d9462c611c3a920a).\n * This patch **passes all tests**.\n * This patch merges cleanly.\n * This patch adds the following public classes _(experimental)_:\n * `class IDF extends Estimator[IDFModel] with IDFParams `\n * `class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] `\n\n * This patch **adds the following new dependencies:**\n * `avro-1.7.7.jar` * `breeze-macros_2.10-0.11.2.jar` * `breeze_2.10-0.11.2.jar`\n * This patch **removes the following dependencies:**\n * `avro-1.7.6.jar` * `breeze-macros_2.10-0.11.1.jar` * `breeze_2.10-0.11.1.jar`"} ``` and turn it into: ```json data: {"body": " [Test build #29400 has finished](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/29400/consoleFull) for PR 5266 at commit [`2aa4be0`](https://github.com/apache/spark/commit/2aa4be0e1d7ce052f8c901c6d9462c611c3a920a).\n * This patch **passes all tests**.\n * This patch merges cleanly.\n * This patch adds the following public classes _(experimental)_:\n * `class IDF extends Estimator[IDFModel] with IDFParams `\n * `class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] `\n\n * This patch **adds the following new dependencies:**\n * `avro-1.7.7.jar`\n * `breeze-macros_2.10-0.11.2.jar`\n * `breeze_2.10-0.11.2.jar`\n * This patch **removes the following dependencies:**\n * `avro-1.7.6.jar`\n * `breeze-macros_2.10-0.11.1.jar`\n * `breeze_2.10-0.11.1.jar`"} ``` I've tested this locally and all worked. /cc srowen pwendell nchammas Author: Brennon York Closes #5269 from brennonyork/HOTFIX-SPARK-4123 and squashes the following commits: a441068 [Brennon York] Updated awk to use printf and to manually insert newlines so that the JSON github string when posted is corrected --- dev/tests/pr_new_dependencies.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh index 115a5cd1354f0..370c7cc737bbd 100755 --- a/dev/tests/pr_new_dependencies.sh +++ b/dev/tests/pr_new_dependencies.sh @@ -90,8 +90,8 @@ else echo " * This patch does not change any dependencies." else # Pretty print the new dependencies - added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{print " * \`"$1"\`"}') - removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{print " * \`"$1"\`"}') + added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') + removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" From f76d2e55b1a67bf5576e1aa001a0b872b9b3895a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Mar 2015 15:47:00 -0700 Subject: [PATCH 082/171] [SPARK-6603] [PySpark] [SQL] add SQLContext.udf and deprecate inferSchema() and applySchema This PR create an alias for `registerFunction` as `udf.register`, to be consistent with Scala API. It also deprecated inferSchema() and applySchema(), show an warning for them. 
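For reference, the Scala API that this alias is meant to line up with already exposes registration through `sqlContext.udf`; a minimal sketch (the UDF name and function are made up here, and `sqlContext` is assumed to be an existing `SQLContext`) looks like:

```scala
// Register a Scala function as a SQL UDF and call it from a query.
sqlContext.udf.register("strLen", (s: String) => s.length)
sqlContext.sql("SELECT strLen('test')").collect()
```

The Python `sqlCtx.udf.register(...)` added in this patch mirrors that shape, while `registerFunction` remains available (the new method simply delegates to it).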
cc rxin Author: Davies Liu Closes #5273 from davies/udf and squashes the following commits: 476e947 [Davies Liu] address comments c096fdb [Davies Liu] add SQLContext.udf and deprecate inferSchema() and applySchema --- python/pyspark/sql/context.py | 87 ++++++++++++++++++++++++----------- 1 file changed, 60 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 795ef0dbc4c47..80939a1f8ab1e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -34,7 +34,7 @@ except ImportError: has_pandas = False -__all__ = ["SQLContext", "HiveContext"] +__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] def _monkey_patch_RDD(sqlCtx): @@ -56,6 +56,31 @@ def toDF(self, schema=None, sampleRatio=None): RDD.toDF = toDF +class UDFRegistration(object): + """Wrapper for register UDF""" + + def __init__(self, sqlCtx): + self.sqlCtx = sqlCtx + + def register(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.udf.register("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + + >>> from pyspark.sql.types import IntegerType + >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + """ + return self.sqlCtx.registerFunction(name, f, returnType) + + class SQLContext(object): """Main entry point for Spark SQL functionality. @@ -118,6 +143,11 @@ def getConf(self, key, defaultValue): """ return self._ssql_ctx.getConf(key, defaultValue) + @property + def udf(self): + """Wrapper for register Python function as UDF """ + return UDFRegistration(self) + def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. 
@@ -198,14 +228,12 @@ def inferSchema(self, rdd, samplingRatio=None): >>> df.collect()[0] Row(field1=1, field2=u'row1') """ + warnings.warn("inferSchema is deprecated, please use createDataFrame instead") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") - schema = self._inferSchema(rdd, samplingRatio) - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) + return self.createDataFrame(rdd, None, samplingRatio) def applySchema(self, rdd, schema): """ @@ -230,6 +258,7 @@ def applySchema(self, rdd, schema): >>> df.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] """ + warnings.warn("applySchema is deprecated, please use createDataFrame instead") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -237,23 +266,7 @@ def applySchema(self, rdd, schema): if not isinstance(schema, StructType): raise TypeError("schema should be StructType, but got %s" % schema) - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + return self.createDataFrame(rdd, schema) def createDataFrame(self, data, schema=None, samplingRatio=None): """ @@ -323,22 +336,42 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if not isinstance(data, RDD): try: # data could be list, tuple, generator ... - data = self._sc.parallelize(data) + rdd = self._sc.parallelize(data) except Exception: raise ValueError("cannot create an RDD from type: %s" % type(data)) + else: + rdd = data if schema is None: - return self.inferSchema(data, samplingRatio) + schema = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(schema) + rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): - first = data.first() + first = rdd.first() if not isinstance(first, (list, tuple)): raise ValueError("each row in `rdd` should be list or tuple, " "but got %r" % type(first)) row_cls = Row(*schema) - schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio) + schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - return self.applySchema(data, schema) + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) def registerDataFrameAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. 
From fde6945417355ae57500b67d034c9cad4f20d240 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 31 Mar 2015 07:48:37 +0800 Subject: [PATCH 083/171] [SPARK-6369] [SQL] Uses commit coordinator to help committing Hive and Parquet tables This PR leverages the output commit coordinator introduced in #4066 to help committing Hive and Parquet tables. This PR extracts output commit code in `SparkHadoopWriter.commit` to `SparkHadoopMapRedUtil.commitTask`, and reuses it for committing Parquet and Hive tables on executor side. TODO - [ ] Add tests [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5139) Author: Cheng Lian Closes #5139 from liancheng/spark-6369 and squashes the following commits: 72eb628 [Cheng Lian] Fixes typo in javadoc 9a4b82b [Cheng Lian] Adds javadoc and addresses @aarondav's comments dfdf3ef [Cheng Lian] Uses commit coordinator to help committing Hive and Parquet tables --- .../org/apache/spark/SparkHadoopWriter.scala | 52 +---------- .../spark/mapred/SparkHadoopMapRedUtil.scala | 91 ++++++++++++++++++- .../sql/parquet/ParquetTableOperations.scala | 11 ++- .../apache/spark/sql/parquet/newParquet.scala | 4 +- .../hive/execution/InsertIntoHiveTable.scala | 1 - .../spark/sql/hive/hiveWriterContainers.scala | 17 +--- 6 files changed, 103 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6eb4537d10477..2ec42d3aea169 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path -import org.apache.spark.executor.CommitDeniedException import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -104,55 +103,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - - // Called after we have decided to commit - def performCommit(): Unit = { - try { - cmtr.commitTask(taCtxt) - logInfo (s"$taID: Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - - // First, check whether the task's output has already been committed by some other attempt - if (cmtr.needsTaskCommit(taCtxt)) { - // The task output needs to be committed, but we don't know whether some other task attempt - // might be racing to commit the same output partition. Therefore, coordinate with the driver - // in order to determine whether this attempt can commit (see SPARK-4879). 
- val shouldCoordinateWithDriver: Boolean = { - val sparkConf = SparkEnv.get.conf - // We only need to coordinate with the driver if there are multiple concurrent task - // attempts, which should only occur if speculation is enabled - val speculationEnabled = sparkConf.getBoolean("spark.speculation", false) - // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs - sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) - } - if (shouldCoordinateWithDriver) { - val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID) - if (canCommit) { - performCommit() - } else { - val msg = s"$taID: Not committed because the driver did not authorize commit" - logInfo(msg) - // We need to abort the task so that the driver can reschedule new attempts, if necessary - cmtr.abortTask(taCtxt) - throw new CommitDeniedException(msg, jobID, splitID, attemptID) - } - } else { - // Speculation is disabled or a user has chosen to manually bypass the commit coordination - performCommit() - } - } else { - // Some other attempt committed the output, so we do nothing and signal success - logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}") - } + SparkHadoopMapRedUtil.commitTask( + getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 87c2aa481095d..818f7a4c8d422 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -17,9 +17,15 @@ package org.apache.spark.mapred +import java.io.IOException import java.lang.reflect.Modifier -import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} + +import org.apache.spark.executor.CommitDeniedException +import org.apache.spark.{Logging, SparkEnv, TaskContext} private[spark] trait SparkHadoopMapRedUtil { @@ -65,3 +71,86 @@ trait SparkHadoopMapRedUtil { } } } + +object SparkHadoopMapRedUtil extends Logging { + /** + * Commits a task output. Before committing the task output, we need to know whether some other + * task attempt might be racing to commit the same output partition. Therefore, coordinate with + * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for + * details). 
+ * + * Output commit coordinator is only contacted when the following two configurations are both set + * to `true`: + * + * - `spark.speculation` + * - `spark.hadoop.outputCommitCoordination.enabled` + */ + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + jobId: Int, + splitId: Int, + attemptId: Int): Unit = { + + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + + // Called after we have decided to commit + def performCommit(): Unit = { + try { + committer.commitTask(mrTaskContext) + logInfo(s"$mrTaskAttemptID: Committed") + } catch { + case cause: IOException => + logError(s"Error committing the output of task: $mrTaskAttemptID", cause) + committer.abortTask(mrTaskContext) + throw cause + } + } + + // First, check whether the task's output has already been committed by some other attempt + if (committer.needsTaskCommit(mrTaskContext)) { + val shouldCoordinateWithDriver: Boolean = { + val sparkConf = SparkEnv.get.conf + // We only need to coordinate with the driver if there are multiple concurrent task + // attempts, which should only occur if speculation is enabled + val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + } + + if (shouldCoordinateWithDriver) { + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + + if (canCommit) { + performCommit() + } else { + val message = + s"$mrTaskAttemptID: Not committed because the driver did not authorize commit" + logInfo(message) + // We need to abort the task so that the driver can reschedule new attempts, if necessary + committer.abortTask(mrTaskContext) + throw new CommitDeniedException(message, jobId, splitId, attemptId) + } + } else { + // Speculation is disabled or a user has chosen to manually bypass the commit coordination + performCommit() + } + } else { + // Some other attempt committed the output, so we do nothing and signal success + logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") + } + } + + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + sparkTaskContext: TaskContext): Unit = { + commitTask( + committer, + mrTaskContext, + sparkTaskContext.stageId(), + sparkTaskContext.partitionId(), + sparkTaskContext.attemptNumber()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 5130d8ad5e003..1c868da23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.parquet import java.io.IOException import java.lang.{Long => JLong} -import java.text.SimpleDateFormat -import java.text.NumberFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.concurrent.{Callable, TimeUnit} -import java.util.{ArrayList, Collections, Date, List => JList} +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -43,12 +42,13 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType 
import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** @@ -356,7 +356,7 @@ private[sql] case class InsertIntoParquetTable( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) 1 } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) @@ -512,6 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter import parquet.filter2.compat.RowGroupFilter + import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 53f765ee26a13..19800ad88c031 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -42,6 +42,7 @@ import parquet.hadoop.{ParquetInputFormat, _} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions @@ -669,7 +670,8 @@ private[sql] case class ParquetRelation2( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) /* apparently we need a TaskAttemptID to construct an OutputCommitter; diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index da53d30354551..cdf012b5117be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -72,7 +72,6 @@ case class InsertIntoHiveTable( val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName assert(outputFileFormatClassName != null, "Output format class not set") conf.value.set("mapred.output.format.class", outputFileFormatClassName) - conf.value.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath( conf.value, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ba2bf67aed684..8398da268174d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import java.io.IOException import 
java.text.NumberFormat import java.util.Date @@ -118,19 +117,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - if (committer.needsTaskCommit(taskContext)) { - try { - committer.commitTask(taskContext) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - committer.abortTask(taskContext) - throw e - } - } else { - logInfo("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -213,7 +200,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) - val colString = + val colString = if (string == null || string.isEmpty) { defaultPartName } else { From b8ff2bc61c9835867f56afa1860ab5eb727c4a58 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 30 Mar 2015 20:47:10 -0700 Subject: [PATCH 084/171] [SPARK-6119][SQL] DataFrame support for missing data handling This pull request adds variants of DataFrame.na.drop and DataFrame.na.fill to the Scala/Java API, and DataFrame.fillna and DataFrame.dropna to the Python API. Author: Reynold Xin Closes #5274 from rxin/df-missing-value and squashes the following commits: 4ee1b98 [Reynold Xin] Improve error reporting in Python. 33a330c [Reynold Xin] Remove replace for now. bc4fdbb [Reynold Xin] Added documentation for replace. d56f5a5 [Reynold Xin] Added replace for Scala/Java. 2385d00 [Reynold Xin] Feedback from Xiangrui on "how". 914a374 [Reynold Xin] fill with map. 185c67e [Reynold Xin] Allow specifying column subsets in fill. 749eb47 [Reynold Xin] fillna 249b94e [Reynold Xin] Removing undefined functions. 6a73c68 [Reynold Xin] Missing file. 67d7003 [Reynold Xin] [SPARK-6119][SQL] DataFrame.na.drop (Scala/Java) and DataFrame.dropna (Python) --- python/pyspark/sql/dataframe.py | 86 +++++++ python/pyspark/sql/tests.py | 96 ++++++++ .../catalyst/expressions/nullFunctions.scala | 25 +- .../org/apache/spark/sql/DataFrame.scala | 15 +- .../spark/sql/DataFrameNaFunctions.scala | 228 ++++++++++++++++++ .../org/apache/spark/sql/GroupedData.scala | 5 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 157 ++++++++++++ 8 files changed, 606 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 23c0e63e77812..4f174de811697 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -690,6 +690,86 @@ def subtract(self, other): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + def dropna(self, how='any', thresh=None, subset=None): + """Returns a new :class:`DataFrame` omitting rows with null values. + + :param how: 'any' or 'all'. + If 'any', drop a row if it contains any nulls. + If 'all', drop a row only if all its values are null. + :param thresh: int, default None + If specified, drop rows that have less than `thresh` non-null values. + This overwrites the `how` parameter. + :param subset: optional list of column names to consider. 
+ + >>> df4.dropna().show() + age height name + 10 80 Alice + """ + if how is not None and how not in ['any', 'all']: + raise ValueError("how ('" + how + "') should be 'any' or 'all'") + + if subset is None: + subset = self.columns + elif isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + if thresh is None: + thresh = len(subset) if how == 'any' else 1 + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) + + def fillna(self, value, subset=None): + """Replace null values. + + :param value: int, long, float, string, or dict. + Value to replace null values with. + If the value is a dict, then `subset` is ignored and `value` must be a mapping + from column name (string) to replacement value. The replacement value must be + an int, long, float, or string. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + + >>> df4.fillna(50).show() + age height name + 10 80 Alice + 5 50 Bob + 50 50 Tom + 50 50 null + + >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + """ + if not isinstance(value, (float, int, long, basestring, dict)): + raise ValueError("value should be a float, int, long, string, or dict") + + if isinstance(value, (int, long)): + value = float(value) + + if isinstance(value, dict): + value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + elif subset is None: + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + else: + if isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + def withColumn(self, colName, col): """ Return a new :class:`DataFrame` by adding a column. 
@@ -1069,6 +1149,12 @@ def _test(): globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), Row(name='Bob', age=5, height=85)]).toDF() + + globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2720439416682..258464b7f230d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -415,6 +415,102 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_dropna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # shouldn't drop a non-null row + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, 80.1)], schema).dropna().count(), + 1) + + # dropping rows with a single null value + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna().count(), + 0) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), + 0) + + # if how = 'all', only drop rows if all values are null + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(None, None, None)], schema).dropna(how='all').count(), + 0) + + # how and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 0) + + # threshold + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(thresh=2).count(), + 0) + + # threshold and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 0) + + # thresh should take precedence over how + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna( + how='any', thresh=2, subset=['name', 'age']).count(), + 1) + + def test_fillna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # fillna shouldn't change non-null values + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + self.assertEqual(row.age, 10) + + # fillna with int + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.0) + + # fillna with double + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + self.assertEqual(row.age, 50) + 
self.assertEqual(row.height, 50.1) + + # fillna with string + row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first() + self.assertEqual(row.name, u"hello") + self.assertEqual(row.age, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, 50) + self.assertEqual(row.height, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + self.assertEqual(row.name, "haha") + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index d1f3d4f4ee9ee..f9161cf34f0c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -35,7 +35,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType: DataType = if (resolved) { + override def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -74,3 +74,26 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } } + +/** + * A predicate that is evaluated to be true if there are at least `n` non-null values. + */ +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = false + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: Row): Boolean = { + var numNonNulls = 0 + var i = 0 + while (i < childrenArray.length && numNonNulls < n) { + if (childrenArray(i).eval(input) != null) { + numNonNulls += 1 + } + i += 1 + } + numNonNulls >= n + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 423ef3912bc89..5cd0a18ff688c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -237,8 +237,8 @@ class DataFrame private[sql]( def toDF(colNames: String*): DataFrame = { require(schema.size == colNames.size, "The number of columns doesn't match.\n" + - "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + - "New column names: " + colNames.mkString(", ")) + s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + + s"New column names (${colNames.size}): " + colNames.mkString(", ")) val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) => apply(oldName).as(newName) @@ -319,6 +319,17 @@ class DataFrame private[sql]( */ def show(): Unit = show(20) + /** + * Returns a [[DataFrameNaFunctions]] for working with missing data. + * {{{ + * // Dropping rows containing any null values. 
+ * df.na.drop() + * }}} + * + * @group dfops + */ + def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) + /** * Cartesian join with another [[DataFrame]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala new file mode 100644 index 0000000000000..3a3dc70f7285c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -0,0 +1,228 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.{lang => jl} + +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + + +/** + * Functionality for working with missing data in [[DataFrame]]s. + */ +final class DataFrameNaFunctions private[sql](df: DataFrame) { + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values. + */ + def drop(): DataFrame = drop("any", df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values. + * + * If `how` is "any", then drop rows containing any null values. + * If `how` is "all", then drop rows only if every column is null for that row. + */ + def drop(how: String): DataFrame = drop(how, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Array[String]): DataFrame = drop(cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Seq[String]): DataFrame = { + how.toLowerCase match { + case "any" => drop(cols.size, cols) + case "all" => drop(1, cols) + case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") + } + } + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values. 
+ */ + def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null + * values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than + * `minNonNulls` non-null values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { + // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values. + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + df.filter(Column(predicate)) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`. + */ + def fill(value: Double): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`. + */ + def fill(value: String): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame]] that replaces null values in specified numeric columns. + * If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified + * numeric columns. If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[Double](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in specified string columns. + * If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in + * specified string columns. If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[String](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * import com.google.common.collect.ImmutableMap; + * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); + * }}} + */ + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. 
+ * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * df.na.fill(Map( + * "A" -> "unknown", + * "B" -> 1.0 + * )) + * }}} + */ + def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + private def fill0(values: Seq[(String, Any)]): DataFrame = { + // Error handling + values.foreach { case (colName, replaceValue) => + // Check column name exists + df.resolve(colName) + + // Check data type + replaceValue match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + // This is good + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") + } + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => + v match { + case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Double => fillCol[Double](f, v) + case v: jl.Long => fillCol[Double](f, v.toDouble) + case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: String => fillCol[String](f, v) + } + }.getOrElse(df.col(f.name)) + } + df.select(projections : _*) + } + + /** + * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. + */ + private def fillCol[T](col: StructField, replacement: T): Column = { + coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45a63ae26ed71..a5e6b638d2150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -127,10 +127,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.builder() - * .put("age", "max") - * .put("expense", "sum") - * .build()); + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); * }}} */ def agg(exprs: java.util.Map[String, String]): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 2b0358c4e2a1e..0b770f2251943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -49,7 +49,7 @@ private[sql] object JsonRDD extends Logging { val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = if (schemaData.isEmpty()) { - Set.empty[(String,DataType)] + Set.empty[(String, DataType)] } else { parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala new file mode 100644 index 0000000000000..0896f175c056f --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameNaFunctionsSuite extends QueryTest { + + def createDF(): DataFrame = { + Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Bob", 16, 176.5), + ("Alice", null, 164.3), + ("David", 60, null), + ("Amy", null, null), + (null, null, null)).toDF("name", "age", "height") + } + + test("drop") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("name" :: Nil), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("age" :: Nil), + rows(0) :: rows(2) :: Nil) + + checkAnswer( + input.na.drop("age" :: "height" :: Nil), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(), + rows(0)) + + // dropna on an a dataframe with no column should return an empty data frame. + val empty = input.sqlContext.emptyDataFrame.select() + assert(empty.na.drop().count() === 0L) + + // Make sure the columns are properly named. + assert(input.na.drop().columns.toSeq === input.columns.toSeq) + } + + test("drop with how") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("all"), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("any"), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("any", Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("all", Seq("age", "height")), + rows(0) :: rows(1) :: rows(2) :: Nil) + } + + test("drop with threshold") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop(2, Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(3, Seq("name", "age", "height")), + rows(0)) + + // Make sure the columns are properly named. + assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq) + } + + test("fill") { + val input = createDF() + + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. 
+ assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil), + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, null) :: + Row("Amy", 50, null) :: + Row(null, 50, null) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + } + + test("fill with map") { + val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( + (null, null, null, null)).toDF("a", "b", "c", "d") + checkAnswer( + df.na.fill(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + )), + Row("test", null, 1, 2.2)) + + // Test Java version + checkAnswer( + df.na.fill(mapAsJavaMap(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + ))), + Row("test", null, 1, 2.2)) + } +} From 56775571cb938c819e5f7c3d49c5dd416ed034cb Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 30 Mar 2015 22:10:49 -0700 Subject: [PATCH 085/171] [SPARK-5124][Core] Move StopCoordinator to the receive method since it does not require a reply Hotfix for #4588 cc rxin Author: zsxwing Closes #5283 from zsxwing/hotfix and squashes the following commits: cf3e5a7 [zsxwing] Move StopCoordinator to the receive method since it does not require a reply --- .../spark/scheduler/OutputCommitCoordinator.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index f748f394d1347..17055e2f22d0d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -156,14 +156,16 @@ private[spark] object OutputCommitCoordinator { override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) extends RpcEndpoint with Logging { + override def receive: PartialFunction[Any, Unit] = { + case StopCoordinator => + logInfo("OutputCommitCoordinator stopped!") + stop() + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => context.reply( outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) - case StopCoordinator => - logInfo("OutputCommitCoordinator stopped!") - context.reply(true) - stop() } } } From f07e714062f02feadff10a45f9b9061444bb8ec5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 31 Mar 2015 00:19:51 -0700 Subject: [PATCH 086/171] [SPARK-6625][SQL] Add common string filters to data sources. Filters such as startsWith, endsWith, contains will be very useful for data sources that provide search functionality, e.g. Succinct, Elastic Search, Solr. I also took this chance to improve documentation for the data source filters. Author: Reynold Xin Closes #5285 from rxin/ds-string-filters and squashes the following commits: f021727 [Reynold Xin] Fixed grammar. 7695a52 [Reynold Xin] [SPARK-6625][SQL] Add common string filters to data sources. 
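As a rough usage sketch (an editor's illustration, not part of the patch; it assumes a `SQLContext` named `sqlContext` and a table "events" backed by a filter-pushing data source with a string column `c`), each predicate below would typically reach the source as a StringStartsWith, StringEndsWith, or StringContains filter respectively, once the optimizer has simplified the LIKE pattern:

{{{
sqlContext.sql("SELECT c FROM events WHERE c LIKE 'abc%'")   // prefix match
sqlContext.sql("SELECT c FROM events WHERE c LIKE '%xyz'")   // suffix match
sqlContext.sql("SELECT c FROM events WHERE c LIKE '%mid%'")  // substring match
}}}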
---
 .../sql/sources/DataSourceStrategy.scala      | 10 +++
 .../apache/spark/sql/sources/filters.scala    | 69 ++++++++++++++++++
 .../apache/spark/sql/sources/interfaces.scala |  3 +
 .../spark/sql/sources/FilteredScanSuite.scala | 73 +++++++++++++------
 4 files changed, 133 insertions(+), 22 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 67f3507c61ab6..83b603a4bb245 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.{Row, Strategy, execution, sources}
 
 /**
@@ -166,6 +167,15 @@ private[sql] object DataSourceStrategy extends Strategy {
     case expressions.Not(child) =>
       translate(child).map(sources.Not)
 
+    case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) =>
+      Some(sources.StringStartsWith(a.name, v))
+
+    case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) =>
+      Some(sources.StringEndsWith(a.name, v))
+
+    case expressions.Contains(a: Attribute, Literal(v: String, StringType)) =>
+      Some(sources.StringContains(a.name, v))
+
     case _ => None
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 1e4505e36d2f0..791046e0079d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -17,16 +17,85 @@
 
 package org.apache.spark.sql.sources
 
+/**
+ * A filter predicate for data sources.
+ */
 abstract class Filter
 
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * equal to `value`.
+ */
 case class EqualTo(attribute: String, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * greater than `value`.
+ */
 case class GreaterThan(attribute: String, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * greater than or equal to `value`.
+ */
 case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * less than `value`.
+ */
 case class LessThan(attribute: String, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to a value
+ * less than or equal to `value`.
+ */
 case class LessThanOrEqual(attribute: String, value: Any) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array.
+ */
 case class In(attribute: String, values: Array[Any]) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to null.
+ */
 case class IsNull(attribute: String) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to a non-null value.
+ */
 case class IsNotNull(attribute: String) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff both `left` and `right` evaluate to `true`.
+ */
 case class And(left: Filter, right: Filter) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff at least one of `left` or `right` evaluates to `true`.
+ */
 case class Or(left: Filter, right: Filter) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff `child` is evaluated to `false`.
+ */
 case class Not(child: Filter) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to
+ * a string that starts with `value`.
+ */
+case class StringStartsWith(attribute: String, value: String) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to
+ * a string that ends with `value`.
+ */
+case class StringEndsWith(attribute: String, value: String) extends Filter
+
+/**
+ * A filter that evaluates to `true` iff the attribute evaluates to
+ * a string that contains the string `value`.
+ */
+case class StringContains(attribute: String, value: String) extends Filter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index a046a48c1733d..8f9946a5a801e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -152,6 +152,9 @@ trait PrunedScan {
  * A BaseRelation that can eliminate unneeded columns and filter using selected
  * predicates before producing an RDD containing all matching tuples as Row objects.
  *
+ * The actual filter should be the conjunction of all `filters`,
+ * i.e. they should be "and"ed together.
+ *
  * The pushed down filters are currently purely an optimization as they will all be evaluated
  * again.  This means it is safe to use them with methods that produce false positives such
 * as filtering partitions based on a bloom filter.
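To make the conjunction contract above concrete, here is a minimal, hypothetical `PrunedFilteredScan` relation (an editor's sketch; the relation name, column name, and data are invented, and the real coverage lives in the `FilteredScanSuite` changes below). It interprets only the new string filters and leaves everything else to Spark's own re-evaluation:

{{{
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

// Single string column "c" backed by an in-memory list of words.
case class WordsRelation(words: Seq[String])(@transient val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan {

  override def schema: StructType =
    StructType(StructField("c", StringType, nullable = false) :: Nil)

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Evaluate a single pushed-down filter against a value of column "c".
    // Unhandled filters pass through here and are evaluated again by Spark.
    def accepts(value: String, filter: Filter): Boolean = filter match {
      case StringStartsWith("c", prefix) => value.startsWith(prefix)
      case StringEndsWith("c", suffix)   => value.endsWith(suffix)
      case StringContains("c", substr)   => value.contains(substr)
      case _                             => true
    }
    sqlContext.sparkContext
      .parallelize(words)
      .filter(w => filters.forall(accepts(w, _)))  // all pushed filters must hold (AND)
      .map(w => Row(w))
  }
}
}}}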
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index ffeccf0b69394..72ddc0ea2c8cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -35,20 +35,23 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL extends BaseRelation with PrunedFilteredScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: Nil) + StructField("b", IntegerType, nullable = false) :: + StructField("c", StringType, nullable = false) :: Nil) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) + case "c" => (i: Int) => Seq((i - 1 + 'a').toChar.toString * 10) } FiltersPushed.list = filters - def translateFilter(filter: Filter): Int => Boolean = filter match { + // Predicate test on integer column + def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v @@ -57,13 +60,27 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) case IsNull("a") => (a: Int) => false // Int can't be null case IsNotNull("a") => (a: Int) => true - case Not(pred) => (a: Int) => !translateFilter(pred)(a) - case And(left, right) => (a: Int) => translateFilter(left)(a) && translateFilter(right)(a) - case Or(left, right) => (a: Int) => translateFilter(left)(a) || translateFilter(right)(a) + case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a) + case And(left, right) => (a: Int) => + translateFilterOnA(left)(a) && translateFilterOnA(right)(a) + case Or(left, right) => (a: Int) => + translateFilterOnA(left)(a) || translateFilterOnA(right)(a) case _ => (a: Int) => true } - def eval(a: Int) = !filters.map(translateFilter(_)(a)).contains(false) + // Predicate test on string column + def translateFilterOnC(filter: Filter): String => Boolean = filter match { + case StringStartsWith("c", v) => _.startsWith(v) + case StringEndsWith("c", v) => _.endsWith(v) + case StringContains("c", v) => _.contains(v) + case _ => (c: String) => true + } + + def eval(a: Int) = { + val c = (a - 1 + 'a').toChar.toString * 10 + !filters.map(translateFilterOnA(_)(a)).contains(false) && + !filters.map(translateFilterOnC(_)(c)).contains(false) + } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) @@ -93,7 +110,7 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 10)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -128,41 +145,53 @@ class FilteredScanSuite extends DataSourceTest { (2 to 10 by 2).map(i => Row(i, i)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * 
FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", + Seq(1,3,5).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE A = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE b = 2", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IS NULL", + "SELECT a, b FROM oneToTenFiltered WHERE a IS NULL", Seq.empty[Row]) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IS NOT NULL", + "SELECT a, b FROM oneToTenFiltered WHERE a IS NOT NULL", (1 to 10).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", + "SELECT a, b FROM oneToTenFiltered WHERE a < 5 AND a > 1", (2 to 4).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", - Seq(1, 2, 9, 10).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 3 OR a > 8", + Seq(1, 2, 9, 10).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", + "SELECT a, b FROM oneToTenFiltered WHERE NOT (a < 6)", (6 to 10).map(i => Row(i, i * 2)).toSeq) + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", + Seq(Row(3, 3 * 2, "c" * 10))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'd%'", + Seq(Row(4, 4 * 2, "d" * 10))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%e%'", + Seq(Row(5, 5 * 2, "e" * 10))) + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) From b80a030e90d790e27e89b26f536565c582dbf3d5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 31 Mar 2015 00:25:23 -0700 Subject: [PATCH 087/171] [SPARK-6623][SQL] Alias DataFrame.na.drop and DataFrame.na.fill in Python. To maintain consistency with the Scala API. Author: Reynold Xin Closes #5284 from rxin/df-na-alias and squashes the following commits: 19f46b7 [Reynold Xin] Show DataFrameNaFunctions in docs. 6618118 [Reynold Xin] [SPARK-6623][SQL] Alias DataFrame.na.drop and DataFrame.na.fill in Python. --- python/pyspark/sql/__init__.py | 10 ++++---- python/pyspark/sql/dataframe.py | 41 +++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 54a01631d8899..9d39e5d9c2449 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -22,22 +22,24 @@ Main entry point for :class:`DataFrame` and SQL functionality. - L{DataFrame} A distributed collection of data grouped into named columns. - - L{GroupedData} - Aggregation methods, returned by :func:`DataFrame.groupBy`. - L{Column} A column expression in a :class:`DataFrame`. - L{Row} A row of data in a :class:`DataFrame`. - L{HiveContext} Main entry point for accessing data stored in Apache Hive. + - L{GroupedData} + Aggregation methods, returned by :func:`DataFrame.groupBy`. + - L{DataFrameNaFunctions} + Methods for handling missing data (null values). - L{functions} List of built-in functions available for :class:`DataFrame`. 
""" from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.types import Row -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions __all__ = [ - 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions' ] diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4f174de811697..15508023326cc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -31,7 +31,7 @@ from pyspark.sql.types import _create_cls, _parse_datatype_json_string -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"] +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] class DataFrame(object): @@ -86,6 +86,12 @@ def applySchema(it): return self._lazy_rdd + @property + def na(self): + """Returns a :class:`DataFrameNaFunctions` for handling missing values. + """ + return DataFrameNaFunctions(self) + def toJSON(self, use_unicode=False): """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row. @@ -693,6 +699,8 @@ def subtract(self, other): def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. + This is an alias for `na.drop`. + :param how: 'any' or 'all'. If 'any', drop a row if it contains any nulls. If 'all', drop a row only if all its values are null. @@ -704,6 +712,10 @@ def dropna(self, how='any', thresh=None, subset=None): >>> df4.dropna().show() age height name 10 80 Alice + + >>> df4.na.drop().show() + age height name + 10 80 Alice """ if how is not None and how not in ['any', 'all']: raise ValueError("how ('" + how + "') should be 'any' or 'all'") @@ -723,7 +735,7 @@ def dropna(self, how='any', thresh=None, subset=None): return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) def fillna(self, value, subset=None): - """Replace null values. + """Replace null values, alias for `na.fill`. :param value: int, long, float, string, or dict. Value to replace null values with. @@ -748,6 +760,13 @@ def fillna(self, value, subset=None): 5 null Bob 50 null Tom 50 null unknown + + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown """ if not isinstance(value, (float, int, long, basestring, dict)): raise ValueError("value should be a float, int, long, string, or dict") @@ -1134,6 +1153,24 @@ def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') +class DataFrameNaFunctions(object): + """Functionality for working with missing data in :class:`DataFrame`. 
+ """ + + def __init__(self, df): + self.df = df + + def drop(self, how='any', thresh=None, subset=None): + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + drop.__doc__ = DataFrame.dropna.__doc__ + + def fill(self, value, subset=None): + return self.df.fillna(value=value, subset=subset) + + fill.__doc__ = DataFrame.fillna.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext From 314afd0e2f08dd8d3333d3143712c2c79fa40d1e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 31 Mar 2015 16:28:40 +0800 Subject: [PATCH 088/171] [SPARK-6618][SQL] HiveMetastoreCatalog.lookupRelation should use fine-grained lock JIRA: https://issues.apache.org/jira/browse/SPARK-6618 Author: Yin Huai Closes #5281 from yhuai/lookupRelationLock and squashes the following commits: 591b4be [Yin Huai] A test? b3a9625 [Yin Huai] Just protect client. --- .../apache/spark/sql/hive/HiveMetastoreCatalog.scala | 12 +++++++++--- .../spark/sql/hive/execution/SQLQuerySuite.scala | 11 +++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 203164ea84292..6a01a23124d95 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -172,12 +172,16 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = synchronized { + alias: Option[String]): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - val table = try client.getTable(databaseName, tblName) catch { + val table = try { + synchronized { + client.getTable(databaseName, tblName) + } + } catch { case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException => throw new NoSuchTableException } @@ -199,7 +203,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else { val partitions: Seq[Partition] = if (table.isPartitioned) { - HiveShim.getAllPartitionsOf(client, table).toSeq + synchronized { + HiveShim.getAllPartitionsOf(client, table).toSeq + } } else { Nil } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1187228f4c3db..2f50a33448462 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -433,4 +433,15 @@ class SQLQuerySuite extends QueryTest { dropTempTable("data") setConf("spark.sql.hive.convertCTAS", originalConf) } + + test("sanity test for SPARK-6618") { + (1 to 100).par.map { i => + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } } From a05835b89fe2086e460f0b80f7c22e284c0c32d0 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 31 Mar 2015 17:05:23 +0800 Subject: [PATCH 089/171] [SPARK-6542][SQL] add CreateStruct Similar to `CreateArray`, we can add `CreateStruct` to create nested columns. 
marmbrus Author: Xiangrui Meng Closes #5195 from mengxr/SPARK-6542 and squashes the following commits: 3795c57 [Xiangrui Meng] update error message ae7ac3e [Xiangrui Meng] move unit test to a separate suite 85dd559 [Xiangrui Meng] use NamedExpr c78e31a [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-6542 85f3106 [Xiangrui Meng] add CreateStruct --- .../sql/catalyst/analysis/Analyzer.scala | 6 ++ .../catalyst/expressions/complexTypes.scala | 29 ++++++++- .../ExpressionEvaluationSuite.scala | 61 ++++++++++++------- 3 files changed, 73 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ba1ac141b9fab..dc14f49e6ee99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -212,6 +212,12 @@ class Analyzer(catalog: Catalog, case o => o :: Nil } Alias(c.copy(children = expandedArgs), name)() :: Nil + case Alias(c @ CreateStruct(args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(c.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 3fd78db297462..3b2b9211268a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -142,3 +142,30 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def toString: String = s"Array(${children.mkString(",")})" } + +/** + * Returns a Row containing the evaluation of all children expressions. + * TODO: [[CreateStruct]] does not support codegen. 
+ */ +case class CreateStruct(children: Seq[NamedExpression]) extends Expression { + override type EvaluatedType = Row + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + assert(resolved, + s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") + val fields = children.map { child => + StructField(child.name, child.dataType, child.nullable, child.metadata) + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: Row): EvaluatedType = { + Row(children.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index dcfd8b28cb02a..1183a0d899dda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -30,7 +30,34 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.types._ -class ExpressionEvaluationSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends FunSuite { + + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.eval(inputRow) + } + + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} + +class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("literals") { checkEvaluation(Literal(1), 1) @@ -134,27 +161,6 @@ class ExpressionEvaluationSuite extends FunSuite { } } - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } - test("IN") { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) @@ -1081,3 +1087,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(~c1, -2, row) } } + +// TODO: Make the tests work with codegen. 
+class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { + + test("CreateStruct") { + val row = Row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + } +} From d01a6d8c33fc5c8325b0cc4b51395dba5eb3462c Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Tue, 31 Mar 2015 11:16:55 -0700 Subject: [PATCH 090/171] [SPARK-4894][mllib] Added Bernoulli option to NaiveBayes model in mllib Added optional model type parameter for NaiveBayes training. Can be either Multinomial or Bernoulli. When Bernoulli is given the Bernoulli smoothing is used for fitting and for prediction as per: http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html. Default for model is original Multinomial fit and predict. Added additional testing for Bernoulli and Multinomial models. Author: leahmcguire Author: Joseph K. Bradley Author: Leah McGuire Closes #4087 from leahmcguire/master and squashes the following commits: f3c8994 [leahmcguire] changed checks on model type to requires acb69af [leahmcguire] removed enum type and replaces all modelType parameters with strings 2224b15 [Leah McGuire] Merge pull request #2 from jkbradley/leahmcguire-master 9ad89ca [Joseph K. Bradley] removed old code 6a8f383 [Joseph K. Bradley] Added new model save/load format 2.0 for NaiveBayesModel after modelType parameter was added. Updated tests. Also updated ModelType enum-like type. 852a727 [leahmcguire] merged with upstream master a22d670 [leahmcguire] changed NaiveBayesModel modelType parameter back to NaiveBayes.ModelType, made NaiveBayes.ModelType serializable, fixed getter method in NavieBayes 18f3219 [leahmcguire] removed private from naive bayes constructor for lambda only bea62af [leahmcguire] put back in constructor for NaiveBayes 01baad7 [leahmcguire] made fixes from code review fb0a5c7 [leahmcguire] removed typo e2d925e [leahmcguire] fixed nonserializable error that was causing naivebayes test failures 2d0c1ba [leahmcguire] fixed typo in NaiveBayes c298e78 [leahmcguire] fixed scala style errors b85b0c9 [leahmcguire] Merge remote-tracking branch 'upstream/master' 900b586 [leahmcguire] fixed model call so that uses type argument ea09b28 [leahmcguire] Merge remote-tracking branch 'upstream/master' e016569 [leahmcguire] updated test suite with model type fix 85f298f [leahmcguire] Merge remote-tracking branch 'upstream/master' dc65374 [leahmcguire] integrated model type fix 7622b0c [leahmcguire] added comments and fixed style as per rb b93aaf6 [Leah McGuire] Merge pull request #1 from jkbradley/nb-model-type 3730572 [Joseph K. Bradley] modified NB model type to be more Java-friendly b61b5e2 [leahmcguire] added back compatable constructor to NaiveBayesModel to fix MIMA test failure 5a4a534 [leahmcguire] fixed scala style error in NaiveBayes 3891bf2 [leahmcguire] synced with apache spark and resolved merge conflict d9477ed [leahmcguire] removed old inaccurate comment from test suite for mllib naive bayes 76e5b0f [leahmcguire] removed unnecessary sort from test 0313c0c [leahmcguire] fixed style error in NaiveBayes.scala 4a3676d [leahmcguire] Updated changes re-comments. Got rid of verbose populateMatrix method. Public api now has string instead of enumeration. Docs are updated." ce73c63 [leahmcguire] added Bernoulli option to niave bayes model in mllib, added optional model type parameter for training. 
When Bernoulli is given the Bernoulli smoothing is used for fitting and for prediction http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html --- docs/mllib-naive-bayes.md | 17 +- .../mllib/classification/NaiveBayes.scala | 225 ++++++++++++++---- .../classification/JavaNaiveBayesSuite.java | 23 +- .../classification/NaiveBayesSuite.scala | 148 +++++++++--- 4 files changed, 322 insertions(+), 91 deletions(-) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index a83472f5be52e..9780ea52c4994 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -13,12 +13,15 @@ compute the conditional probability distribution of label given an observation and use it for prediction. MLlib supports [multinomial naive -Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), -which is typically used for [document -classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) +and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification] +(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each -feature represents a term whose value is the frequency of the term. -Feature values must be nonnegative to represent term frequencies. +feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). +Feature values must be nonnegative. The model type is selected with an optional parameter +"Multinomial" or "Bernoulli" with "Multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -32,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, and output a +smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. 
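For the Bernoulli variant described above, the call is identical apart from the model type string, and the feature vectors must be 0/1 presence indicators. A minimal sketch (editor's example, assuming an existing `SparkContext` named `sc`; the two documents are made up):

{{{
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// 0/1 term-presence features, as the Bernoulli model expects.
val docs = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))))

val bernoulliModel = NaiveBayes.train(docs, lambda = 1.0, modelType = "Bernoulli")
bernoulliModel.predict(Vectors.dense(0.0, 1.0, 0.0))
}}}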
@@ -51,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
 val training = splits(0)
 val test = splits(1)
 
-val model = NaiveBayes.train(training, lambda = 1.0)
+val model = NaiveBayes.train(training, lambda = 1.0, modelType = "Multinomial")
 
 val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
 val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index d60e82c410979..c9b3ff0172e2e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,9 +21,12 @@ import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConverters._
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.numerics.{exp => brzExp, log => brzLog}
+
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JValue}
 
 import org.apache.spark.{Logging, SparkContext, SparkException}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
@@ -32,6 +35,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
+
 /**
  * Model for Naive Bayes Classifiers.
  *
@@ -39,11 +43,17 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
  * @param pi log of class priors, whose dimension is C, number of labels
  * @param theta log of class conditional probabilities, whose dimension is C-by-D,
  *              where D is number of features
+ * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli"
  */
 class NaiveBayesModel private[mllib] (
     val labels: Array[Double],
     val pi: Array[Double],
-    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
+    val theta: Array[Array[Double]],
+    val modelType: String)
+  extends ClassificationModel with Serializable with Saveable {
+
+  private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
+    this(labels, pi, theta, "Multinomial")
 
   /** A Java-friendly constructor that takes three Iterable parameters. */
   private[mllib] def this(
@@ -53,19 +63,19 @@ class NaiveBayesModel private[mllib] (
     this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
 
   private val brzPi = new BDV[Double](pi)
-  private val brzTheta = new BDM[Double](theta.length, theta(0).length)
-
-  {
-    // Need to put an extra pair of braces to prevent Scala treating `i` as a member.
-    var i = 0
-    while (i < theta.length) {
-      var j = 0
-      while (j < theta(i).length) {
-        brzTheta(i, j) = theta(i)(j)
-        j += 1
-      }
-      i += 1
-    }
+  private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
+
+  // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
+  // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
+  // application of this condition (in predict function).
+ private val (brzNegTheta, brzNegThetaSum) = modelType match { + case "Multinomial" => (None, None) + case "Bernoulli" => + val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) + (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") } override def predict(testData: RDD[Vector]): RDD[Double] = { @@ -77,22 +87,78 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { - labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) + modelType match { + case "Multinomial" => + labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + case "Bernoulli" => + labels (brzArgmax (brzPi + + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + } } override def save(sc: SparkContext, path: String): Unit = { - val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta) - NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) + NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = "2.0" } object NaiveBayesModel extends Loader[NaiveBayesModel] { import org.apache.spark.mllib.util.Loader._ - private object SaveLoadV1_0 { + private[mllib] object SaveLoadV2_0 { + + def thisFormatVersion: String = "2.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]], + modelType: String) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. 
+ checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + val modelType = data.getString(3) + new NaiveBayesModel(labels, pi, theta, modelType) + } + + } + + private[mllib] object SaveLoadV1_0 { def thisFormatVersion: String = "1.0" @@ -100,7 +166,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" /** Model data for model import/export */ - case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { val sqlContext = new SQLContext(sc) @@ -136,26 +205,32 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { override def load(sc: SparkContext, path: String): NaiveBayesModel = { val (loadedClassName, version, metadata) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName - (loadedClassName, version) match { + val classNameV2_0 = SaveLoadV2_0.thisClassName + val (model, numFeatures, numClasses) = (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) val model = SaveLoadV1_0.load(sc, path) - assert(model.pi.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class priors vector pi had ${model.pi.size} elements") - assert(model.theta.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class conditionals array theta had ${model.theta.size} elements") - assert(model.theta.forall(_.size == numFeatures), - s"NaiveBayesModel.load expected $numFeatures features," + - s" but class conditionals array theta had elements of size:" + - s" ${model.theta.map(_.size).mkString(",")}") - model + (model, numFeatures, numClasses) + case (className, "2.0") if className == classNameV2_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV2_0.load(sc, path) + (model, numFeatures, numClasses) case _ => throw new Exception( s"NaiveBayesModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). Supported:\n" + s" ($classNameV1_0, 1.0)") } + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model } } @@ -167,9 +242,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. 
*/ -class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { - def this() = this(1.0) +class NaiveBayes private ( + private var lambda: Double, + private var modelType: String) extends Serializable with Logging { + + def this(lambda: Double) = this(lambda, "Multinomial") + + def this() = this(1.0, "Multinomial") /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -177,9 +257,24 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with this } - /** Get the smoothing parameter. Default: 1.0. */ + /** Get the smoothing parameter. */ def getLambda: Double = lambda + /** + * Set the model type using a string (case-sensitive). + * Supported options: "Multinomial" and "Bernoulli". + * (default: Multinomial) + */ + def setModelType(modelType:String): NaiveBayes = { + require(NaiveBayes.supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + this.modelType = modelType + this + } + + /** Get the model type. */ + def getModelType: String = this.modelType + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * @@ -213,21 +308,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => (c1._1 + c2._1, c1._2 += c2._2) ).collect() + val numLabels = aggregated.length var numDocuments = 0L aggregated.foreach { case (_, (n, _)) => numDocuments += n } val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } + val labels = new Array[Double](numLabels) val pi = new Array[Double](numLabels) val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) + val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => labels(i) = label - val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) pi(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = modelType match { + case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case "Bernoulli" => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") + } var j = 0 while (j < numFeatures) { theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom @@ -236,7 +340,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with i += 1 } - new NaiveBayesModel(labels, pi, theta) + new NaiveBayesModel(labels, pi, theta, modelType) } } @@ -244,13 +348,16 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * Top-level methods for calling naive Bayes. */ object NaiveBayes { + + /* Set of modelTypes that NaiveBayes supports */ + private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") + /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. 
For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * This version of the method uses a default smoothing parameter of 1.0. * @@ -264,16 +371,40 @@ object NaiveBayes { /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda).run(input) + new NaiveBayes(lambda, "Multinomial").run(input) + } + + /** + * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. + * + * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]]) + * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle + * discrete count data and can be called by setting the model type to "multinomial". + * For example, it can be used with word counts or TF_IDF vectors of documents. + * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a + * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as + * Bernoulli NB. + * + * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency + * vector or a count vector. 
+ * @param lambda The smoothing parameter + * + * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be + * multinomial or bernoulli + */ + def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { + require(supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + new NaiveBayes(lambda, modelType).run(input) } + } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 1c90522a0714a..71fb7f13c39c2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,20 +17,22 @@ package org.apache.spark.mllib.classification; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; public class JavaNaiveBayesSuite implements Serializable { private transient JavaSparkContext sc; @@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception { // Should be able to get the first prediction. 
predictions.first(); } + + @Test + public void testModelTypeSetters() { + NaiveBayes nb = new NaiveBayes() + .setModelType("Bernoulli") + .setModelType("Multinomial"); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 5a27c7d2309c5..f9fe3e006ccb8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.mllib.classification import scala.util.Random +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.stats.distributions.{Multinomial => BrzMultinomial} + import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -41,37 +44,48 @@ object NaiveBayesSuite { // Generate input of the form Y = (theta * x).argmax() def generateNaiveBayesInput( - pi: Array[Double], // 1XC - theta: Array[Array[Double]], // CXD - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + pi: Array[Double], // 1XC + theta: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int, + modelType: String = "Multinomial", + sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) - val _pi = pi.map(math.pow(math.E, _)) val _theta = theta.map(row => row.map(math.pow(math.E, _))) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) - val xi = Array.tabulate[Double](D) { j => - if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 + val xi = modelType match { + case "Bernoulli" => Array.tabulate[Double] (D) { j => + if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 + } + case "Multinomial" => + val mult = BrzMultinomial(BDV(_theta(y))) + val emptyMap = (0 until D).map(x => (x, 0.0)).toMap + val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { + case (index, reps) => (index, reps.size.toDouble) + } + counts.toArray.sortBy(_._1).map(_._2) + case _ => + // This should never happen. 
+ throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") } LabeledPoint(y, Vectors.dense(xi)) } } - private val smallPi = Array(0.5, 0.3, 0.2).map(math.log) + /** Bernoulli NaiveBayes with binary labels, 3 features */ + private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Bernoulli") - private val smallTheta = Array( - Array(0.91, 0.03, 0.03, 0.03), // label 0 - Array(0.03, 0.91, 0.03, 0.03), // label 1 - Array(0.03, 0.03, 0.91, 0.03) // label 2 - ).map(_.map(math.log)) - - /** Binary labels, 3 features */ - private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), - theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4))) + /** Multinomial NaiveBayes with binary labels, 3 features */ + private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Multinomial") } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -85,6 +99,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } + def validateModelFit( + piData: Array[Double], + thetaData: Array[Array[Double]], + model: NaiveBayesModel) = { + def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { + (d1 - d2).abs <= precision + } + val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) + for (i <- modelIndex) { + assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) + } + for (i <- modelIndex) { + for (j <- 0 until thetaData(i._2).length) { + assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + } + } + } + test("get, set params") { val nb = new NaiveBayes() nb.setLambda(2.0) @@ -93,19 +125,53 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(nb.getLambda === 3.0) } - test("Naive Bayes") { - val nPoints = 10000 + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 42, "Multinomial") + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + validateModelFit(pi, theta, model) + + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 17, "Multinomial") + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - val pi = NaiveBayesSuite.smallPi - val theta = NaiveBayesSuite.smallTheta + // Test prediction on Array. 
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } - val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) + test("Naive Bayes Bernoulli") { + val nPoints = 10000 + val pi = Array(0.5, 0.3, 0.2).map(math.log) + val theta = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 45, "Bernoulli") val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD) + val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + validateModelFit(pi, theta, model) - val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17) + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 20, "Bernoulli") val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -142,19 +208,41 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } - test("model save/load") { - val model = NaiveBayesSuite.binaryModel + test("model save/load: 2.0 to 2.0") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).map { + model => + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === sameModel.modelType) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("model save/load: 1.0 to 2.0") { + val model = NaiveBayesSuite.binaryMultinomialModel val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - // Save model, load it back, and compare. + // Save model as version 1.0, load it back, and compare. try { - model.save(sc, path) + val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) val sameModel = NaiveBayesModel.load(sc, path) assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) + assert(model.modelType === "Multinomial") } finally { Utils.deleteRecursively(tempDir) } @@ -172,8 +260,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble()))) } } - // If we serialize data directly in the task closure, the size of the serialized task would be - // greater than 1MB and hence Spark would throw an error. + // If we serialize data directly in the task closure, the size of the serialized task + // would be greater than 1MB and hence Spark would throw an error. val model = NaiveBayes.train(examples) val predictions = model.predict(examples.map(_.features)) } From a7992ffaf1e8adc9d2c225a986fa3162e8e130eb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 31 Mar 2015 11:18:25 -0700 Subject: [PATCH 091/171] [SPARK-6555] [SQL] Overrides equals() and hashCode() for MetastoreRelation Also removes temporary workarounds made in #5183 and #5251. 
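(Editor's note: the sketch below is not part of the patch. It illustrates the value-equality pattern the change relies on, using a hypothetical `Relation` class and `java.util.Objects` in place of Guava's `Objects`: once `equals()` and `hashCode()` are defined over the identifying fields, logically identical relations collapse to the same map key, so the earlier workaround of keying maps by `relation -> relation.output` is no longer needed.)

```scala
import java.util.Objects

// Hypothetical, simplified stand-in for a relation whose identity is defined by
// its logical fields rather than by object reference.
class Relation(val databaseName: String, val tableName: String, val alias: Option[String]) {
  override def equals(other: Any): Boolean = other match {
    case r: Relation =>
      databaseName == r.databaseName && tableName == r.tableName && alias == r.alias
    case _ => false
  }
  override def hashCode(): Int = Objects.hash(databaseName, tableName, alias)
}

object ValueEqualityDemo extends App {
  val a = new Relation("default", "src", None)
  val b = new Relation("default", "src", None)
  // With value-based equality both instances hit the same map entry.
  val cache = Map(a -> "converted relation")
  println(cache.contains(b)) // true
}
```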
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/5289) Author: Cheng Lian Closes #5289 from liancheng/spark-6555 and squashes the following commits: d0095ac [Cheng Lian] Removes unused imports cfafeeb [Cheng Lian] Removes outdated comment 75a2746 [Cheng Lian] Overrides equals() and hashCode() for MetastoreRelation --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 42 +++++++++++-------- .../sql/hive/execution/HivePlanTest.scala | 6 ++- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6a01a23124d95..f20f0ad99f865 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} +import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.metastore.{TableType, Warehouse} @@ -465,7 +466,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) // Write path case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) @@ -476,7 +477,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) @@ -485,33 +486,28 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) } - // Quick fix for SPARK-6450: Notice that we're using both the MetastoreRelation instances and - // their output attributes as the key of the map. This is because MetastoreRelation.equals - // doesn't take output attributes into account, thus multiple MetastoreRelation instances - // pointing to the same table get collapsed into a single entry in the map. A proper fix for - // this should be overriding equals & hashCode in MetastoreRelation. val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes // attribute IDs referenced in other nodes. 
plan.transformUp { - case r: MetastoreRelation if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) + case r: MetastoreRelation if relationMap.contains(r) => + val parquetRelation = relationMap(r) val alias = r.alias.getOrElse(r.tableName) Subquery(alias, parquetRelation) case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) InsertIntoTable(parquetRelation, partition, child, overwrite) case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) InsertIntoTable(parquetRelation, partition, child, overwrite) case other => other.transformExpressions { @@ -707,6 +703,19 @@ private[hive] case class MetastoreRelation self: Product => + override def equals(other: scala.Any): Boolean = other match { + case relation: MetastoreRelation => + databaseName == relation.databaseName && + tableName == relation.tableName && + alias == relation.alias && + output == relation.output + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(databaseName, tableName, alias, output) + } + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. // Right now, using org.apache.hadoop.hive.ql.metadata.Table and @@ -786,10 +795,7 @@ private[hive] case class MetastoreRelation val columnOrdinals = AttributeMap(attributes.zipWithIndex) override def newInstance() = { - val newCopy = MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) - // The project here is an ugly hack to work around the fact that MetastoreRelation's - // equals method is broken. Please remove this when SPARK-6555 is fixed. - Project(newCopy.output, newCopy) + MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index c939e6e99d28a..bdb53ddf59c19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { import TestHive._ + import TestHive.implicits._ test("udf constant folding") { - val optimized = sql("SELECT cos(null) FROM src").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM src").queryExecution.optimizedPlan + Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") + val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } From 81020144708773ba3af4932288ffa09ef901269e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 31 Mar 2015 11:21:15 -0700 Subject: [PATCH 092/171] [SPARK-6575] [SQL] Adds configuration to disable schema merging while converting metastore Parquet tables Consider a metastore Parquet table that 1. doesn't have schema evolution issue 2. 
has lots of data files and/or partitions. In this case, driver schema merging can be both slow and unnecessary. Would be good to have a configuration to let the user disable schema merging when converting such a metastore Parquet table. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5231) Author: Cheng Lian Closes #5231 from liancheng/spark-6575 and squashes the following commits: cd96159 [Cheng Lian] Adds configuration to disable schema merging while converting metastore Parquet tables --- .../org/apache/spark/sql/hive/HiveContext.scala | 9 +++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 ++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c06c2e396bbc1..6bb1c47dba920 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -57,6 +57,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] def convertMetastoreParquet: Boolean = getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. + */ + protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + /** * When true, a table created by a Hive CTAS statement (no USING clause) will be * converted to a data source table, using the data source set by spark.sql.sources.default.
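(Editor's note: a minimal usage sketch, not part of the patch; the table name is a placeholder. The new flag is read through `getConf`, so it can be set like any other SQL configuration, and it only takes effect while metastore Parquet conversion itself is enabled.)

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object MergeSchemaConfigSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("mergeSchema-sketch"))
    val hiveContext = new HiveContext(sc)

    // Keep metastore Parquet conversion on (its default) ...
    hiveContext.setConf("spark.sql.hive.convertMetastoreParquet", "true")
    // ... and leave driver-side schema merging off, which is also the default
    // introduced by this patch, to avoid merging schemas of many files/partitions.
    hiveContext.setConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false")

    // "some_parquet_table" is a placeholder for an existing metastore Parquet table.
    hiveContext.sql("SELECT COUNT(*) FROM some_parquet_table").show()

    sc.stop()
  }
}
```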
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f20f0ad99f865..2b5d031741a63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -218,6 +218,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + val parquetOptions = Map( + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the @@ -234,18 +238,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } val partitionSpec = PartitionSpec(partitionSchema, partitions) val paths = partitions.map(_.path) - LogicalRelation( - ParquetRelation2( - paths, - Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json), - None, - Some(partitionSpec))(hive)) + LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) } else { val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) - LogicalRelation( - ParquetRelation2( - paths, - Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) + LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) } } From cd48ca50129e8952f487051796244e7569275416 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 31 Mar 2015 11:23:18 -0700 Subject: [PATCH 093/171] [SPARK-6145][SQL] fix ORDER BY on nested fields This PR is based on work by cloud-fan in #4904, but with two differences: - We isolate the logic for Sort's special handling into `ResolveSortReferences` - We avoid creating UnresolvedGetField expressions during resolution. Instead we either resolve GetField or we return None. This avoids us going down the wrong path early on. Author: Michael Armbrust Closes #5189 from marmbrus/nestedOrderBy and squashes the following commits: b8cae45 [Michael Armbrust] fix another test 0f36a11 [Michael Armbrust] WIP 91820cd [Michael Armbrust] Fix bug. 
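(Editor's note: for context, the queries this change enables look like the sketch below, which mirrors the new `SQLQuerySuite` test further down; it is an illustration, not part of the patch. Previously, ORDER BY clauses referencing nested struct fields or fields of array elements failed to resolve.)

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object NestedOrderBySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("nested-order-by"))
    val sqlContext = new SQLContext(sc)

    // Same data shape as the test case added to SQLQuerySuite below.
    sqlContext.jsonRDD(sc.makeRDD(
      """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil))
      .registerTempTable("nestedOrder")

    // Ordering by nested struct fields and by fields of array elements now resolves.
    sqlContext.sql("SELECT a.b FROM nestedOrder ORDER BY a.b").show()
    sqlContext.sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a").show()
    sqlContext.sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d").show()

    sc.stop()
  }
}
```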
--- .../sql/catalyst/analysis/Analyzer.scala | 76 ++++++++++++++----- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++- .../catalyst/expressions/AttributeSet.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 76 +++++++++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 39 +++++++++- .../org/apache/spark/sql/SQLContext.scala | 14 ++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 19 +++-- .../spark/sql/sources/DataSourceTest.scala | 4 + 8 files changed, 185 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dc14f49e6ee99..c578d084a45b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and * a [[FunctionRegistry]]. */ -class Analyzer(catalog: Catalog, - registry: FunctionRegistry, - caseSensitive: Boolean, - maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { +class Analyzer( + catalog: Catalog, + registry: FunctionRegistry, + caseSensitive: Boolean, + maxIterations: Int = 100) + extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution @@ -354,19 +355,16 @@ class Analyzer(catalog: Catalog, def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) - val resolved = unresolved.flatMap(child.resolve(_, resolver)) - val requiredAttributes = - AttributeSet(resolved.flatMap(_.collect { case a: Attribute => a })) + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) - val missingInProject = requiredAttributes -- p.output - if (missingInProject.nonEmpty) { + // If this rule was not a no-op, return the transformed plan, otherwise return the original. + if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. - Project(projectList.map(_.toAttribute), - Sort(ordering, global, - Project(projectList ++ missingInProject, child))) + Project(p.output, + Sort(resolvedOrdering, global, + Project(projectList ++ missing, child))) } else { - logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") + logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) @@ -378,18 +376,54 @@ class Analyzer(catalog: Catalog, grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver)) - val missingInAggs = resolved.filterNot(a.outputSet.contains) - logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") - if (missingInAggs.nonEmpty) { + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) + + if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. 
Project(a.output, - Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child))) + Sort(resolvedOrdering, global, + Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. } } + + /** + * Given a child and a grandchild that are present beneath a sort operator, returns + * a resolved sort ordering and a list of attributes that are missing from the child + * but are present in the grandchild. + */ + def resolveAndFindMissing( + ordering: Seq[SortOrder], + child: LogicalPlan, + grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + // Find any attributes that remain unresolved in the sort. + val unresolved: Seq[String] = + ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + + // Create a map from name, to resolved attributes, when the desired name can be found + // prior to the projection. + val resolved: Map[String, NamedExpression] = + unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap + + // Construct a set that contains all of the attributes that we need to evaluate the + // ordering. + val requiredAttributes = AttributeSet(resolved.values) + + // Figure out which ones are missing from the projection, so that we can add them and + // remove them after the sort. + val missingInProject = requiredAttributes -- child.output + + // Now that we have all the attributes we need, reconstruct a resolved ordering. + // It is important to do it here, instead of waiting for the standard resolved as adding + // attributes to the project below can actually introduce ambiquity that was not present + // before. + val resolvedOrdering = ordering.map(_ transform { + case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u) + }).asInstanceOf[Seq[SortOrder]] + + (resolvedOrdering, missingInProject.toSeq) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 40472a1cbb3b4..fa02111385c06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.types._ /** * Throws user facing errors when passed invalid queries that fail to analyze. */ -class CheckAnalysis { +trait CheckAnalysis { + self: Analyzer => /** * Override to provide additional checks for correct analysis. @@ -33,17 +34,22 @@ class CheckAnalysis { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil - def failAnalysis(msg: String): Nothing = { + protected def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } - def apply(plan: LogicalPlan): Unit = { + def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => + if (operator.childrenResolved) { + // Throw errors for specific problems with get field. 
+ operator.resolveChildren(a.name, resolver, throwErrors = true) + } + val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 11b4eb5c888be..5345696570b41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -34,7 +34,7 @@ object AttributeSet { def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]): AttributeSet = { + def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b01a61d7bf8d6..2e9f3aa4ec4ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.{ArrayType, StructType, StructField} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -109,16 +110,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver) + def resolveChildren( + name: String, + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(name, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, output, resolver) + def resolve( + name: String, + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(name, output, resolver, throwErrors) /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. @@ -162,7 +169,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { protected def resolve( name: String, input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { + resolver: Resolver, + throwErrors: Boolean): Option[NamedExpression] = { val parts = name.split("\\.") @@ -196,14 +204,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. 
case Seq((a, nestedFields)) => - // The foldLeft adds UnresolvedGetField for every remaining parts of the name, - // and aliased it with the last part of the name. - // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias - // the final expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField) - val aliasName = nestedFields.last - Some(Alias(fieldExprs, aliasName)()) + try { + + // The foldLeft adds UnresolvedGetField for every remaining parts of the name, + // and aliased it with the last part of the name. + // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias + // the final expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver)) + val aliasName = nestedFields.last + Some(Alias(fieldExprs, aliasName)()) + } catch { + case a: AnalysisException if !throwErrors => None + } // No matches. case Seq() => @@ -212,11 +225,46 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // More than one match. case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") + val referenceNames = ambiguousReferences.map(_._1).mkString(", ") throw new AnalysisException( s"Reference '$name' is ambiguous, could be: $referenceNames.") } } + + /** + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + * + * TODO: this code is duplicated from Analyzer and should be refactored to avoid this. + */ + protected def resolveGetField( + expr: Expression, + fieldName: String, + resolver: Resolver): Expression = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 756cd36f05c8c..ee7b14c7a157c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -40,14 +40,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { override val extendedResolutionRules = EliminateSubQueries :: Nil } - val checkAnalysis = new CheckAnalysis - def caseSensitiveAnalyze(plan: LogicalPlan) = - checkAnalysis(caseSensitiveAnalyzer(plan)) + caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) def 
caseInsensitiveAnalyze(plan: LogicalPlan) = - checkAnalysis(caseInsensitiveAnalyzer(plan)) + caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( @@ -57,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -169,6 +182,24 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { "'b'" :: "group by" :: Nil ) + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + case class UnresolvedTestPlan() extends LeafNode { override lazy val resolved = false override def output = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b8100782ec937..1794936a52c6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -120,6 +120,10 @@ class SQLContext(@transient val sparkContext: SparkContext) ExtractPythonUdfs :: sources.PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } @transient @@ -1065,14 +1069,6 @@ class SQLContext(@transient val sparkContext: SparkContext) Batch("Add exchange", Once, AddExchange(self)) :: Nil } - @transient - protected[sql] lazy val checkAnalysis = new CheckAnalysis { - override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) - ) - } - - protected[sql] def openSession(): SQLSession = { detachSession() val session = createSession() @@ -1105,7 +1101,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi protected[sql] class QueryExecution(val logical: LogicalPlan) { - def assertAnalyzed(): Unit = checkAnalysis(analyzed) + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) lazy val analyzed: LogicalPlan = analyzer(logical) lazy val withCachedData: LogicalPlan = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a3c0076e16d6c..87e7cf8c8af9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1084,10 +1084,19 @@ class SQLQuerySuite extends 
QueryTest with BeforeAndAfterAll { test("SPARK-6145: ORDER BY test for nested fields") { jsonRDD(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder") - // These should be successfully analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY a.b").queryExecution.analyzed - sql("SELECT a.b FROM nestedOrder ORDER BY a.b").queryExecution.analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a").queryExecution.analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d").queryExecution.analyzed + + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + } + + test("SPARK-6145: special cases") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 91c6367371f15..33c67355967dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -32,6 +32,10 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { override val extendedResolutionRules = PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } } } From 46de6c05e0619250346f0988e296849f8f93d2b1 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 31 Mar 2015 11:25:21 -0700 Subject: [PATCH 094/171] [SPARK-6598][MLLIB] Python API for IDFModel This is the sub-task of SPARK-6254. Wrapping IDFModel `idf` member function for pyspark. Author: lewuathe Closes #5264 from Lewuathe/SPARK-6598 and squashes the following commits: 1dc522c [lewuathe] [SPARK-6598] Python API for IDFModel --- python/pyspark/mllib/feature.py | 6 ++++++ python/pyspark/mllib/tests.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 0ffe092a07365..4bfe3014ef748 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -244,6 +244,12 @@ def transform(self, x): x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) + def idf(self): + """ + Returns the current IDF vector. 
+ """ + return self.call('idf') + class IDF(object): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 155019638f806..3bb0f0ca68128 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -41,6 +41,7 @@ from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import IDF from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -620,6 +621,19 @@ def test_right_number_of_results(self): self.assertEqual(len(chi), num_cols) self.assertIsNotNone(chi[1000]) + +class FeatureTest(PySparkTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" From b5bd75d90a761199c3f9cb583c1fe48c8fda7780 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 31 Mar 2015 11:32:14 -0700 Subject: [PATCH 095/171] [SPARK-6255] [MLLIB] Support multiclass classification in Python API Python API parity check for classification and multiclass classification support; the following major gaps needed to be filled for Python: ```scala LogisticRegressionWithLBFGS setNumClasses setValidateData LogisticRegressionModel getThreshold numClasses numFeatures SVMWithSGD setValidateData SVMModel getThreshold ``` For users, the greatest benefit of this PR is that multiclass classification is now supported by the Python API: users can train a multiclass classification model and use it for prediction in PySpark.
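(Editor's note: for reference, the existing Scala API that PySpark is being brought to parity with can be used roughly as sketched below; this is an illustration using the same toy data as the new PySpark doctest, not part of the patch.)

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object MulticlassLogisticRegressionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("multiclass-lr"))

    // Three classes, one indicative feature per class.
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 0.0)),
      LabeledPoint(2.0, Vectors.dense(0.0, 0.0, 1.0))))

    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(3)
      .run(data)

    println(model.predict(Vectors.dense(0.0, 0.5, 0.0))) // expected 0.0
    println(model.predict(Vectors.dense(0.8, 0.0, 0.0))) // expected 1.0

    sc.stop()
  }
}
```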
Author: Yanbo Liang Closes #5137 from yanboliang/spark-6255 and squashes the following commits: 0bd531e [Yanbo Liang] address comments 444d5e2 [Yanbo Liang] LogisticRegressionModel.predict() optimization fc7990b [Yanbo Liang] address comments b0d9c63 [Yanbo Liang] Support Mulinomial LR model predict in Python API ded847c [Yanbo Liang] Python API parity check for classification (support multiclass classification) --- .../mllib/api/python/PythonMLLibAPI.scala | 22 ++- python/pyspark/mllib/classification.py | 134 ++++++++++++++---- python/pyspark/mllib/regression.py | 10 +- 3 files changed, 134 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 22fa684fd2895..662ec5fbed453 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -77,7 +77,13 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector): JList[Object] = { try { val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + if (model.isInstanceOf[LogisticRegressionModel]) { + val lrModel = model.asInstanceOf[LogisticRegressionModel] + List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses) + .map(_.asInstanceOf[Object]).asJava + } else { + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } } finally { data.rdd.unpersist(blocking = false) } @@ -190,9 +196,11 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) + .setValidateData(validateData) SVMAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -216,9 +224,11 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -242,9 +252,13 @@ private[python] class PythonMLLibAPI extends Serializable { regType: String, intercept: Boolean, corrections: Int, - tolerance: Double): JList[Object] = { + tolerance: Double, + validateData: Boolean, + numClasses: Int): JList[Object] = { val LogRegAlg = new LogisticRegressionWithLBFGS() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) + .setNumClasses(numClasses) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 6766f3ebb8894..2466e8ac43458 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -22,7 +22,7 @@ from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, 
LinearModel, _regression_train_wrapper from pyspark.mllib.util import Saveable, Loader, inherit_doc @@ -31,13 +31,13 @@ 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] -class LinearBinaryClassificationModel(LinearModel): +class LinearClassificationModel(LinearModel): """ - Represents a linear binary classification model that predicts to whether an - example is positive (1.0) or negative (0.0). + A private abstract class representing a multiclass classification model. + The categories are represented by int values: 0, 1, 2, etc. """ def __init__(self, weights, intercept): - super(LinearBinaryClassificationModel, self).__init__(weights, intercept) + super(LinearClassificationModel, self).__init__(weights, intercept) self._threshold = None def setThreshold(self, value): @@ -47,14 +47,26 @@ def setThreshold(self, value): Sets the threshold that separates positive predictions from negative predictions. An example with prediction score greater than or equal to this threshold is identified as an positive, and negative otherwise. + It is used for binary classification only. """ self._threshold = value + @property + def threshold(self): + """ + .. note:: Experimental + + Returns the threshold (if any) used for converting raw prediction scores + into 0/1 predictions. It is used for binary classification only. + """ + return self._threshold + def clearThreshold(self): """ .. note:: Experimental Clears the threshold so that `predict` will output raw prediction scores. + It is used for binary classification only. """ self._threshold = None @@ -66,7 +78,7 @@ def predict(self, test): raise NotImplementedError -class LogisticRegressionModel(LinearBinaryClassificationModel): +class LogisticRegressionModel(LinearClassificationModel): """A linear binary classification model derived from logistic regression. @@ -112,10 +124,39 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ... os.removedirs(path) ... except: ... pass + >>> multi_class_data = [ + ... LabeledPoint(0.0, [0.0, 1.0, 0.0]), + ... LabeledPoint(1.0, [1.0, 0.0, 0.0]), + ... LabeledPoint(2.0, [0.0, 0.0, 1.0]) + ... 
] + >>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3) + >>> mcm.predict([0.0, 0.5, 0.0]) + 0 + >>> mcm.predict([0.8, 0.0, 0.0]) + 1 + >>> mcm.predict([0.0, 0.0, 0.3]) + 2 """ - def __init__(self, weights, intercept): + def __init__(self, weights, intercept, numFeatures, numClasses): super(LogisticRegressionModel, self).__init__(weights, intercept) + self._numFeatures = int(numFeatures) + self._numClasses = int(numClasses) self._threshold = 0.5 + if self._numClasses == 2: + self._dataWithBiasSize = None + self._weightsMatrix = None + else: + self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1) + self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1, + self._dataWithBiasSize) + + @property + def numFeatures(self): + return self._numFeatures + + @property + def numClasses(self): + return self._numClasses def predict(self, x): """ @@ -126,20 +167,38 @@ def predict(self, x): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) - margin = self.weights.dot(x) + self._intercept - if margin > 0: - prob = 1 / (1 + exp(-margin)) + if self.numClasses == 2: + margin = self.weights.dot(x) + self._intercept + if margin > 0: + prob = 1 / (1 + exp(-margin)) + else: + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) + if self._threshold is None: + return prob + else: + return 1 if prob > self._threshold else 0 else: - exp_margin = exp(margin) - prob = exp_margin / (1 + exp_margin) - if self._threshold is None: - return prob - else: - return 1 if prob > self._threshold else 0 + best_class = 0 + max_margin = 0.0 + if x.size + 1 == self._dataWithBiasSize: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i][0:x.size]) + \ + self._weightsMatrix[i][x.size] + if margin > max_margin: + max_margin = margin + best_class = i + 1 + else: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i]) + if margin > max_margin: + max_margin = margin + best_class = i + 1 + return best_class def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( - _py2java(sc, self._coeff), self.intercept) + _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses) java_model.save(sc._jsc.sc(), path) @classmethod @@ -148,8 +207,10 @@ def load(cls, sc, path): sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) intercept = java_model.intercept() + numFeatures = java_model.numFeatures() + numClasses = java_model.numClasses() threshold = java_model.getThreshold().get() - model = LogisticRegressionModel(weights, intercept) + model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses) model.setThreshold(threshold) return model @@ -158,7 +219,8 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.01, regType="l2", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False, + validateData=True): """ Train a logistic regression model on the given data. @@ -184,11 +246,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. 
+ (default: True) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -197,7 +262,7 @@ class LogisticRegressionWithLBFGS(object): @classmethod def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", - intercept=False, corrections=10, tolerance=1e-4): + intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -223,6 +288,11 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType update (default: 10). :param tolerance: The convergence tolerance of iterations for L-BFGS (default: 1e-4). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) + :param numClasses: The number of classes (i.e., outcomes) a label can take + in Multinomial Logistic Regression (default: 2). >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -237,12 +307,20 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, float(regParam), regType, bool(intercept), int(corrections), - float(tolerance)) - + float(tolerance), bool(validateData), int(numClasses)) + + if initialWeights is None: + if numClasses == 2: + initialWeights = [0.0] * len(data.first().features) + else: + if intercept: + initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1) + else: + initialWeights = [0.0] * len(data.first().features) * (numClasses - 1) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) -class SVMModel(LinearBinaryClassificationModel): +class SVMModel(LinearClassificationModel): """A support vector machine. @@ -325,7 +403,8 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): + miniBatchFraction=1.0, initialWeights=None, regType="l2", + intercept=False, validateData=True): """ Train a support vector machine on the given data. @@ -351,11 +430,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 209f1ee473b5b..cd7310a64f4ae 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -167,13 +167,19 @@ def load(cls, sc, path): # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. 
def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) if initial_weights is None: initial_weights = [0.0] * len(data.first().features) - weights, intercept = train_func(data, _convert_to_vector(initial_weights)) - return modelClass(weights, intercept) + if (modelClass == LogisticRegressionModel): + weights, intercept, numFeatures, numClasses = train_func( + data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept, numFeatures, numClasses) + else: + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): From beebb7ffc21c66ae3e4c615555194d1e19ede1bb Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 31 Mar 2015 11:34:29 -0700 Subject: [PATCH 096/171] [SPARK-5371][SQL] Propagate types after function conversion, before further resolution Before it was possible for a query to flip back and forth from a resolved state, allowing resolution to propagate up before coercion had stabilized. The issue was that `ResolvedReferences` would run after `FunctionArgumentConversion`, but before `PropagateTypes` had run. This PR ensures we correctly `PropagateTypes` after any coercion has been applied. Author: Michael Armbrust Closes #5278 from marmbrus/unionNull and squashes the following commits: dc3581a [Michael Armbrust] [SPARK-5371][SQL] Propogate types after function conversion / before futher resolution --- .../catalyst/analysis/HiveTypeCoercion.scala | 1 + .../plans/logical/basicOperators.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 26 ++++++++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 34ef7d28cc7f2..3c7b46e0702a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -78,6 +78,7 @@ trait HiveTypeCoercion { FunctionArgumentConversion :: CaseWhenCoercion :: Division :: + PropagateTypes :: Nil /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 190209238a4a5..8633e06093cf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -80,7 +80,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved: Boolean = childrenResolved && - !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } + left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2f50a33448462..2065f0d60d92f 100644 ---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -41,8 +41,32 @@ case class NestedArray1(a: NestedArray2) */ class SQLQuerySuite extends QueryTest { + test("SPARK-5371: union with null and sum") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + val query = sql( + """ + |SELECT + | MIN(c1), + | MIN(c2) + |FROM ( + | SELECT + | SUM(c1) c1, + | NULL c2 + | FROM table1 + | UNION ALL + | SELECT + | NULL c1, + | SUM(c2) c2 + | FROM table1 + |) a + """.stripMargin) + checkAnswer(query, Row(1, 1) :: Nil) + } + test("explode nested Field") { - Seq(NestedArray1(NestedArray2(Seq(1,2,3)))).toDF.registerTempTable("nestedArray") + Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") checkAnswer( sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), Row(1) :: Row(2) :: Row(3) :: Nil) From 2036bc5993022da550f0cb1c0485ae92ec3e6fb0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 31 Mar 2015 13:18:07 -0700 Subject: [PATCH 097/171] [SPARK-6633][SQL] Should be "Contains" instead of "EndsWith" when constructing sources.StringContains Author: Liang-Chi Hsieh Closes #5299 from viirya/stringcontains and squashes the following commits: c1ece4c [Liang-Chi Hsieh] Should be Contains instead of EndsWith. --- .../scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 83b603a4bb245..e13759b7feb7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -173,7 +173,7 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => Some(sources.StringEndsWith(a.name, v)) - case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => + case expressions.Contains(a: Attribute, Literal(v: String, StringType)) => Some(sources.StringContains(a.name, v)) case _ => None From 0e00f12d33d28d064c166262b14e012a1aeaa7b0 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 31 Mar 2015 16:01:08 -0700 Subject: [PATCH 098/171] [SPARK-5692] [MLlib] Word2Vec save/load Word2Vec model now supports saving and loading. a] The Metadata stored in JSON format consists of "version", "classname", "vectorSize" and "numWords" b] The data stored in Parquet file format consists of an Array of rows with each row consisting of 2 columns, first being the word: String and the second, an Array of Floats. 
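As context for the storage format described above, here is a minimal end-to-end sketch of the save/load API this patch adds. It uses only the names visible in the diff below (`Word2Vec`, `Word2VecModel.save`/`load`, `getVectors`); the corpus, vector size, and output path are illustrative assumptions, not part of the patch.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}

    object Word2VecSaveLoadExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("w2v-save-load").setMaster("local[2]"))

        // Tiny illustrative corpus; every word occurs 10 times, clearing Word2Vec's default minCount.
        val corpus = sc.parallelize(Seq.fill(10)("spark is a fast general engine".split(" ").toSeq))
        val model = new Word2Vec().setVectorSize(10).fit(corpus)

        // save() writes the JSON metadata (version, classname, vectorSize, numWords) and the
        // (word, vector) pairs as a Parquet file, as described above.
        val path = "/tmp/word2vec-model"   // illustrative output location
        model.save(sc, path)

        // load() validates the metadata, reads the Parquet rows back and rebuilds the word -> vector map.
        val sameModel = Word2VecModel.load(sc, path)
        assert(sameModel.getVectors.mapValues(_.toSeq) == model.getVectors.mapValues(_.toSeq))

        sc.stop()
      }
    }

Storing the vectors as Parquet rather than plain text keeps the load path schema-checked (see `Loader.checkSchema` in the diff below) and lets a trained model be reused without re-training.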
Author: MechCoder Closes #5291 from MechCoder/spark-5692 and squashes the following commits: 1142f3a [MechCoder] Add numWords to metaData bfe4c39 [MechCoder] [SPARK-5692] Word2Vec save/load --- .../apache/spark/mllib/feature/Word2Vec.scala | 87 ++++++++++++++++++- .../spark/mllib/feature/Word2VecSuite.scala | 26 ++++++ 2 files changed, 110 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 59a79e5c6a4ac..9ee7e4a66b535 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -25,14 +25,21 @@ import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.sql.{SQLContext, Row} /** * Entry in vocabulary @@ -422,7 +429,7 @@ class Word2Vec extends Serializable with Logging { */ @Experimental class Word2VecModel private[mllib] ( - private val model: Map[String, Array[Float]]) extends Serializable { + private val model: Map[String, Array[Float]]) extends Serializable with Saveable { private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -432,7 +439,13 @@ class Word2VecModel private[mllib] ( if (norm1 == 0 || norm2 == 0) return 0.0 blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 } - + + override protected def formatVersion = "1.0" + + def save(sc: SparkContext, path: String): Unit = { + Word2VecModel.SaveLoadV1_0.save(sc, path, model) + } + /** * Transforms a word to its vector representation * @param word a word @@ -475,7 +488,7 @@ class Word2VecModel private[mllib] ( .tail .toArray } - + /** * Returns a map of words to their vector representations. */ @@ -483,3 +496,71 @@ class Word2VecModel private[mllib] ( model } } + +@Experimental +object Word2VecModel extends Loader[Word2VecModel] { + + private object SaveLoadV1_0 { + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel" + + case class Data(word: String, vector: Array[Float]) + + def load(sc: SparkContext, path: String): Word2VecModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + + val dataArray = dataFrame.select("word", "vector").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. 
+ Loader.checkSchema[Data](dataFrame.schema) + + val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap + new Word2VecModel(word2VecMap) + } + + def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]) = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val vectorSize = model.values.head.size + val numWords = model.size + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } + sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + } + + override def load(sc: SparkContext, path: String): Word2VecModel = { + + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedVectorSize = (metadata \ "vectorSize").extract[Int] + val expectedNumWords = (metadata \ "numWords").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, loadedVersion) match { + case (classNameV1_0, "1.0") => + val model = SaveLoadV1_0.load(sc, path) + val vectorSize = model.getVectors.values.head.size + val numWords = model.getVectors.size + require(expectedVectorSize == vectorSize, + s"Word2VecModel requires each word to be mapped to a vector of size " + + s"$expectedVectorSize, got vector of size $vectorSize") + require(expectedNumWords == numWords, + s"Word2VecModel requires $expectedNumWords words, but got $numWords") + model + case _ => throw new Exception( + s"Word2VecModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 52278690dbd89..98a98a7599bcb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -21,6 +21,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + class Word2VecSuite extends FunSuite with MLlibTestSparkContext { // TODO: add more tests @@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { assert(syms(0)._1 == "taiwan") assert(syms(1)._1 == "japan") } + + test("model load / save") { + + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + + } } From 37326079d818fdb140415a65653767d997613dac Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 31 Mar 2015 16:18:39 -0700 Subject: [PATCH 099/171] [SPARK-6614] OutputCommitCoordinator should clear authorized committer only after authorized committer fails, not after any failure In OutputCommitCoordinator, there is some logic to clear the authorized committer's lock on committing in case that task fails. However, it looks like the current code also clears this lock if other non-authorized tasks fail, which is an obvious bug. In theory, it's possible that this could allow a new committer to start, run to completion, and commit output before the authorized committer finished, but it's unlikely that this race occurs often in practice due to the complex combination of failure and timing conditions that would be required to expose it. This patch addresses this issue and adds a regression test. Thanks to aarondav for spotting this issue. 
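For readers who have not seen the surrounding commit protocol, the sketch below shows where this authorization check sits from a writer task's point of view. It is an illustration only: `OutputCommitCoordinator` is internal to Spark (the real task path goes through `SparkHadoopMapRedUtil.commitTask`), and the helper name and parameters here are assumptions, not code from this patch.

    import org.apache.hadoop.mapred.{OutputCommitter, TaskAttemptContext}
    import org.apache.spark.SparkEnv

    object CommitProtocolSketch {
      // Hypothetical helper: only the attempt the driver has authorized for this
      // (stage, partition) may publish its output; with this fix, a failure of any
      // *other* attempt no longer releases that authorization.
      def commitIfAuthorized(
          committer: OutputCommitter,
          ctx: TaskAttemptContext,
          stage: Int,
          partition: Long,
          attempt: Long): Boolean = {
        val coordinator = SparkEnv.get.outputCommitCoordinator
        if (coordinator.canCommit(stage, partition, attempt)) {
          committer.commitTask(ctx)   // authorized attempt: publish the task output
          true
        } else {
          committer.abortTask(ctx)    // denied attempt: discard the output instead
          false
        }
      }
    }

Gating the lock removal on the failed attempt actually being the authorized one (the `exists(_ == attempt)` check in the hunk below) is what keeps a denied attempt's failure from re-opening this gate.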
Author: Josh Rosen Closes #5276 from JoshRosen/SPARK-6614 and squashes the following commits: d532ba7 [Josh Rosen] Check whether failed task was authorized committer cbb3784 [Josh Rosen] Add regression test for SPARK-6614 --- .../scheduler/OutputCommitCoordinator.scala | 8 +++--- .../OutputCommitCoordinatorSuite.scala | 25 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 17055e2f22d0d..9e29fd13821dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -113,9 +113,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { logInfo( s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") case otherReason => - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") - authorizedCommitters.remove(partition) + if (authorizedCommitters.get(partition).exists(_ == attempt)) { + logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + + s" clearing lock") + authorizedCommitters.remove(partition) + } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index c8c957856247a..cf97707946706 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -161,6 +161,31 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { } assert(tempDir.list().size === 0) } + + test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { + val stage: Int = 1 + val partition: Long = 2 + val authorizedCommitter: Long = 3 + val nonAuthorizedCommitter: Long = 100 + outputCommitCoordinator.stageStart(stage) + assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + // The non-authorized committer fails + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + // New tasks should still not be able to commit because the authorized committer has not failed + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + // The authorized committer now fails, clearing the lock + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + // A new task should now be allowed to become the authorized committer + assert( + outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + // There can only be one authorized committer + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + } } /** From 305abe1e57450f49e3ec4dffb073c5adf17cadef Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 31 Mar 2015 18:31:36 -0700 Subject: [PATCH 100/171] [Doc] Improve Python DataFrame documentation Author: Reynold Xin Closes #5287 from rxin/pyspark-df-doc-cleanup-context and squashes the following commits: 1841b60 [Reynold 
Xin] Lint. f2007f1 [Reynold Xin] functions and types. bc3b72b [Reynold Xin] More improvements to DataFrame Python doc. ac1d4c0 [Reynold Xin] Bug fix. b163365 [Reynold Xin] Python fix. Added Experimental flag to DataFrameNaFunctions. 608422d [Reynold Xin] [Doc] Cleanup context.py Python docs. --- python/pyspark/sql/__init__.py | 4 +- python/pyspark/sql/context.py | 227 ++++++---------- python/pyspark/sql/dataframe.py | 249 +++++++++--------- python/pyspark/sql/functions.py | 6 +- python/pyspark/sql/types.py | 154 +++-------- .../spark/sql/DataFrameNaFunctions.scala | 3 + 6 files changed, 253 insertions(+), 390 deletions(-) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 9d39e5d9c2449..65abb24eed823 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -16,7 +16,7 @@ # """ -public classes of Spark SQL: +Important classes of Spark SQL and DataFrames: - L{SQLContext} Main entry point for :class:`DataFrame` and SQL functionality. @@ -34,6 +34,8 @@ Methods for handling missing data (null values). - L{functions} List of built-in functions available for :class:`DataFrame`. + - L{types} + List of data types available. """ from pyspark.sql.context import SQLContext, HiveContext diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 80939a1f8ab1e..c2d81ba804110 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -40,9 +40,9 @@ def _monkey_patch_RDD(sqlCtx): def toDF(self, schema=None, sampleRatio=None): """ - Convert current :class:`RDD` into a :class:`DataFrame` + Converts current :class:`RDD` into a :class:`DataFrame` - This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)` + This is a shorthand for ``sqlCtx.createDataFrame(rdd, schema, sampleRatio)`` :param schema: a StructType or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring @@ -56,49 +56,23 @@ def toDF(self, schema=None, sampleRatio=None): RDD.toDF = toDF -class UDFRegistration(object): - """Wrapper for register UDF""" - - def __init__(self, sqlCtx): - self.sqlCtx = sqlCtx - - def register(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - >>> sqlCtx.udf.register("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] - - >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] - """ - return self.sqlCtx.registerFunction(name, f, returnType) - - class SQLContext(object): - """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{DataFrame}, register L{DataFrame} as + A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. - """ - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. - - It will add a method called `toDF` to :class:`RDD`, which could be - used to convert an RDD into a DataFrame, it's a shorthand for - :func:`SQLContext.createDataFrame`. 
+ When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`, + which could be used to convert an RDD into a DataFrame, it's a shorthand for + :func:`SQLContext.createDataFrame`. - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + :param sparkContext: The :class:`SparkContext` backing this SQLContext. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Creates a new SQLContext. >>> from datetime import datetime >>> sqlCtx = SQLContext(sc) @@ -145,7 +119,7 @@ def getConf(self, key, defaultValue): @property def udf(self): - """Wrapper for register Python function as UDF """ + """Returns a :class:`UDFRegistration` for UDF registration.""" return UDFRegistration(self) def registerFunction(self, name, f, returnType=StringType()): @@ -155,6 +129,10 @@ def registerFunction(self, name, f, returnType=StringType()): When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. + :param name: name of the UDF + :param samplingRatio: lambda function + :param returnType: a :class:`DataType` object + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() [Row(c0=u'4')] @@ -163,6 +141,11 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] + + >>> from pyspark.sql.types import IntegerType + >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) @@ -203,30 +186,7 @@ def _inferSchema(self, rdd, samplingRatio=None): return schema def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - ::note: - Deprecated in 1.3, use :func:`createDataFrame` instead - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ warnings.warn("inferSchema is deprecated, please use createDataFrame instead") @@ -236,27 +196,7 @@ def inferSchema(self, rdd, samplingRatio=None): return self.createDataFrame(rdd, None, samplingRatio) def applySchema(self, rdd, schema): - """ - Applies the given schema to the given RDD of L{tuple} or L{list}. 
- - ::note: - Deprecated in 1.3, use :func:`createDataFrame` instead - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> from pyspark.sql.types import * - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> df = sqlCtx.applySchema(rdd2, schema) - >>> df.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ warnings.warn("applySchema is deprecated, please use createDataFrame instead") @@ -270,25 +210,23 @@ def applySchema(self, rdd, schema): def createDataFrame(self, data, schema=None, samplingRatio=None): """ - Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame. + Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, + list or :class:`pandas.DataFrame`. - `schema` could be :class:`StructType` or a list of column names. + When ``schema`` is a list of column names, the type of each column + will be inferred from ``data``. - When `schema` is a list of column names, the type of each column - will be inferred from `rdd`. + When ``schema`` is ``None``, it will try to infer the schema (column names and types) + from ``data``, which should be an RDD of :class:`Row`, + or :class:`namedtuple`, or :class:`dict`. - When `schema` is None, it will try to infer the column name and type - from `rdd`, which should be an RDD of :class:`Row`, or namedtuple, - or dict. + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of + rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. - If referring needed, `samplingRatio` is used to determined how many - rows will be used to do referring. The first row will be used if - `samplingRatio` is None. - - :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame - :param schema: a StructType or list of names of columns + :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`, + :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`StructType` or list of column names. default None. :param samplingRatio: the sample ratio of rows used for inferring - :return: a DataFrame >>> l = [('Alice', 1)] >>> sqlCtx.createDataFrame(l).collect() @@ -373,22 +311,20 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return DataFrame(df, self) - def registerDataFrameAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. + def registerDataFrameAsTable(self, df, tableName): + """Registers the given :class:`DataFrame` as a temporary table in the catalog. - Temporary tables exist only during the lifetime of this instance of - SQLContext. + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. 
>>> sqlCtx.registerDataFrameAsTable(df, "table1") """ - if (rdd.__class__ is DataFrame): - df = rdd._jdf - self._ssql_ctx.registerDataFrameAsTable(df, tableName) + if (df.__class__ is DataFrame): + self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) else: raise ValueError("Can only register DataFrame as table") def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a L{DataFrame}. + """Loads a Parquet file, returning the result as a :class:`DataFrame`. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() @@ -406,15 +342,10 @@ def parquetFile(self, *paths): return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{DataFrame}. + """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -450,13 +381,10 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. + """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. >>> df1 = sqlCtx.jsonRDD(json) >>> df1.first() @@ -475,7 +403,6 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): >>> df3 = sqlCtx.jsonRDD(json, schema) >>> df3.first() Row(field2=u'row1', field3=Row(field5=None)) - """ def func(iterator): @@ -496,11 +423,11 @@ def func(iterator): return DataFrame(df, self) def load(self, path=None, source=None, schema=None, **options): - """Returns the dataset in a data source as a DataFrame. + """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Optionally, a schema can be provided as the schema of the returned DataFrame. """ @@ -526,11 +453,11 @@ def createExternalTable(self, tableName, path=None, source=None, It returns the DataFrame associated with the external table. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. 
- Optionally, a schema can be provided as the schema of the returned DataFrame and + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. """ if path is not None: @@ -551,7 +478,7 @@ def createExternalTable(self, tableName, path=None, source=None, return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{DataFrame} representing the result of the given query. + """Returns a :class:`DataFrame` representing the result of the given query. >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") @@ -561,7 +488,7 @@ def sql(self, sqlQuery): return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{DataFrame}. + """Returns the specified table as a :class:`DataFrame`. >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> df2 = sqlCtx.table("table1") @@ -571,12 +498,12 @@ def table(self, tableName): return DataFrame(self._ssql_ctx.table(tableName), self) def tables(self, dbName=None): - """Returns a DataFrame containing names of tables in the given database. + """Returns a :class:`DataFrame` containing names of tables in the given database. - If `dbName` is not specified, the current database will be used. + If ``dbName`` is not specified, the current database will be used. - The returned DataFrame has two columns, tableName and isTemporary - (a column with BooleanType indicating if a table is a temporary one or not). + The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` + (a column with :class:`BooleanType` indicating if a table is a temporary one or not). >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> df2 = sqlCtx.tables() @@ -589,9 +516,9 @@ def tables(self, dbName=None): return DataFrame(self._ssql_ctx.tables(dbName), self) def tableNames(self, dbName=None): - """Returns a list of names of tables in the database `dbName`. + """Returns a list of names of tables in the database ``dbName``. - If `dbName` is not specified, the current database will be used. + If ``dbName`` is not specified, the current database will be used. >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlCtx.tableNames() @@ -618,22 +545,18 @@ def clearCache(self): class HiveContext(SQLContext): - """A variant of Spark SQL that integrates with data stored in Hive. - Configuration for Hive is read from hive-site.xml on the classpath. + Configuration for Hive is read from ``hive-site.xml`` on the classpath. It supports running both SQL and HiveQL commands. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new + :class:`HiveContext` in the JVM, instead we make all calls to this object. """ def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. 
- """ SQLContext.__init__(self, sparkContext) - if hiveContext: self._scala_HiveContext = hiveContext @@ -652,6 +575,18 @@ def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) +class UDFRegistration(object): + """Wrapper for user-defined function registration.""" + + def __init__(self, sqlCtx): + self.sqlCtx = sqlCtx + + def register(self, name, f, returnType=StringType()): + return self.sqlCtx.registerFunction(name, f, returnType) + + register.__doc__ = SQLContext.registerFunction.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 15508023326cc..c30326ebd133e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -35,8 +35,7 @@ class DataFrame(object): - - """A collection of rows that have the same columns. + """A distributed collection of data grouped into named columns. A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: @@ -69,9 +68,7 @@ def __init__(self, jdf, sql_ctx): @property def rdd(self): - """ - Return the content of the :class:`DataFrame` as an :class:`pyspark.RDD` - of :class:`Row` s. + """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ if not hasattr(self, '_lazy_rdd'): jrdd = self._jdf.javaToPython() @@ -93,7 +90,9 @@ def na(self): return DataFrameNaFunctions(self) def toJSON(self, use_unicode=False): - """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row. + """Converts a :class:`DataFrame` into a :class:`RDD` of string. + + Each row is turned into a JSON document as one element in the returned RDD. >>> df.toJSON().first() '{"age":2,"name":"Alice"}' @@ -102,10 +101,10 @@ def toJSON(self, use_unicode=False): return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. + """Saves the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a :class:`DataFrame` using the L{SQLContext.parquetFile} method. + a :class:`DataFrame` using :func:`SQLContext.parquetFile`. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() @@ -120,8 +119,8 @@ def saveAsParquetFile(self, path): def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this DataFrame. + The lifetime of this temporary table is tied to the :class:`SQLContext` + that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") >>> df2 = sqlCtx.sql("select * from people") @@ -131,7 +130,7 @@ def registerTempTable(self, name): self._jdf.registerTempTable(name) def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" + """DEPRECATED: use :func:`registerTempTable` instead""" warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) self.registerTempTable(name) @@ -162,22 +161,19 @@ def _java_save_mode(self, mode): return jmode def saveAsTable(self, tableName, source=None, mode="error", **options): - """Saves the contents of the :class:`DataFrame` to a data source as a table. + """Saves the contents of this :class:`DataFrame` to a data source as a table. - The data source is specified by the `source` and a set of `options`. 
- If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Additionally, mode is used to specify the behavior of the saveAsTable operation when table already exists in the data source. There are four modes: - * append: Contents of this :class:`DataFrame` are expected to be appended \ - to existing table. - * overwrite: Data in the existing table is expected to be overwritten by \ - the contents of this DataFrame. - * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of the \ - :class:`DataFrame` and to not change the existing table. + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. """ if source is None: source = self.sql_ctx.getConf("spark.sql.sources.default", @@ -190,18 +186,17 @@ def saveAsTable(self, tableName, source=None, mode="error", **options): def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Additionally, mode is used to specify the behavior of the save operation when data already exists in the data source. There are four modes: - * append: Contents of this :class:`DataFrame` are expected to be appended to existing data. - * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. - * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of \ - the :class:`DataFrame` and to not change the existing data. + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. """ if path is not None: options["path"] = path @@ -215,8 +210,7 @@ def save(self, path=None, source=None, mode="error", **options): @property def schema(self): - """Returns the schema of this :class:`DataFrame` (represented by - a L{StructType}). + """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) @@ -237,11 +231,9 @@ def printSchema(self): print (self._jdf.schema().treeString()) def explain(self, extended=False): - """ - Prints the plans (logical and physical) to the console for - debugging purpose. + """Prints the (logical and physical) plans to the console for debugging purpose. - If extended is False, only prints the physical plan. + :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... 
@@ -263,15 +255,13 @@ def explain(self, extended=False): print self._jdf.queryExecution().executedPlan().toString() def isLocal(self): - """ - Returns True if the `collect` and `take` methods can be run locally + """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally (without any Spark executors). """ return self._jdf.isLocal() def show(self, n=20): - """ - Print the first n rows. + """Prints the first ``n`` rows to the console. >>> df DataFrame[age: int, name: string] @@ -286,11 +276,7 @@ def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the DataFrame, - which supports features such as filter pushdown. + """Returns the number of rows in this :class:`DataFrame`. >>> df.count() 2L @@ -298,10 +284,7 @@ def count(self): return self._jdf.count() def collect(self): - """Return a list that contains all of the rows. - - Each object in the list is a Row, the fields can be accessed as - attributes. + """Returns all the records as a list of :class:`Row`. >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -313,7 +296,7 @@ def collect(self): return [cls(r) for r in rs] def limit(self, num): - """Limit the result count to the number specified. + """Limits the result count to the number specified. >>> df.limit(1).collect() [Row(age=2, name=u'Alice')] @@ -324,10 +307,7 @@ def limit(self, num): return DataFrame(jdf, self.sql_ctx) def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. + """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -335,9 +315,9 @@ def take(self, num): return self.limit(num).collect() def map(self, f): - """ Return a new RDD by applying a function to each Row + """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. - It's a shorthand for df.rdd.map() + This is a shorthand for ``df.rdd.map()``. >>> df.map(lambda p: p.name).collect() [u'Alice', u'Bob'] @@ -345,10 +325,10 @@ def map(self, f): return self.rdd.map(f) def flatMap(self, f): - """ Return a new RDD by first applying a function to all elements of this, + """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, and then flattening the results. - It's a shorthand for df.rdd.flatMap() + This is a shorthand for ``df.rdd.flatMap()``. >>> df.flatMap(lambda p: p.name).collect() [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] @@ -356,10 +336,9 @@ def flatMap(self, f): return self.rdd.flatMap(f) def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition. + """Returns a new :class:`RDD` by applying the ``f`` function to each partition. - It's a shorthand for df.rdd.mapPartitions() + This is a shorthand for ``df.rdd.mapPartitions()``. >>> rdd = sc.parallelize([1, 2, 3, 4], 4) >>> def f(iterator): yield 1 @@ -369,10 +348,9 @@ def mapPartitions(self, f, preservesPartitioning=False): return self.rdd.mapPartitions(f, preservesPartitioning) def foreach(self, f): - """ - Applies a function to all rows of this DataFrame. + """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. 
- It's a shorthand for df.rdd.foreach() + This is a shorthand for ``df.rdd.foreach()``. >>> def f(person): ... print person.name @@ -381,10 +359,9 @@ def foreach(self, f): return self.rdd.foreach(f) def foreachPartition(self, f): - """ - Applies a function to each partition of this DataFrame. + """Applies the ``f`` function to each partition of this :class:`DataFrame`. - It's a shorthand for df.rdd.foreachPartition() + This a shorthand for ``df.rdd.foreachPartition()``. >>> def f(people): ... for person in people: @@ -394,14 +371,14 @@ def foreachPartition(self, f): return self.rdd.foreachPartition(f) def cache(self): - """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ Persists with the default storage level (C{MEMORY_ONLY_SER}). """ self.is_cached = True self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - """ Set the storage level to persist its values across operations + """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). @@ -412,7 +389,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): return self def unpersist(self, blocking=True): - """ Mark it as non-persistent, and remove all blocks for it from + """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from memory and disk. """ self.is_cached = False @@ -424,8 +401,7 @@ def unpersist(self, blocking=True): # return DataFrame(rdd, self.sql_ctx) def repartition(self, numPartitions): - """ Return a new :class:`DataFrame` that has exactly `numPartitions` - partitions. + """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. >>> df.repartition(10).rdd.getNumPartitions() 10 @@ -433,8 +409,7 @@ def repartition(self, numPartitions): return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) def distinct(self): - """ - Return a new :class:`DataFrame` containing the distinct rows in this DataFrame. + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. >>> df.distinct().count() 2L @@ -442,8 +417,7 @@ def distinct(self): return DataFrame(self._jdf.distinct(), self.sql_ctx) def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this DataFrame. + """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 97).count() 1L @@ -455,7 +429,7 @@ def sample(self, withReplacement, fraction, seed=None): @property def dtypes(self): - """Return all column names and their data types as a list. + """Returns all column names and their data types as a list. >>> df.dtypes [('age', 'int'), ('name', 'string')] @@ -464,7 +438,7 @@ def dtypes(self): @property def columns(self): - """ Return all column names as a list. + """Returns all column names as a list. >>> df.columns [u'age', u'name'] @@ -472,13 +446,14 @@ def columns(self): return [f.name for f in self.schema.fields] def join(self, other, joinExprs=None, joinType=None): - """ - Join with another :class:`DataFrame`, using the given join expression. - The following performs a full outer join between `df1` and `df2`. + """Joins with another :class:`DataFrame`, using the given join expression. + + The following performs a full outer join between ``df1`` and ``df2``. 
:param other: Right side of the join :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + :param joinType: str, default 'inner'. + One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] @@ -496,9 +471,9 @@ def join(self, other, joinExprs=None, joinType=None): return DataFrame(jdf, self.sql_ctx) def sort(self, *cols): - """ Return a new :class:`DataFrame` sorted by the specified column(s). + """Returns a new :class:`DataFrame` sorted by the specified column(s). - :param cols: The columns or expressions used for sorting + :param cols: list of :class:`Column` to sort by. >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] @@ -539,7 +514,9 @@ def describe(self, *cols): return DataFrame(jdf, self.sql_ctx) def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. + """ + Returns the first ``n`` rows as a list of :class:`Row`, + or the first :class:`Row` if ``n`` is ``None.`` >>> df.head() Row(age=2, name=u'Alice') @@ -552,7 +529,7 @@ def head(self, n=None): return self.take(n) def first(self): - """ Return the first row. + """Returns the first row as a :class:`Row`. >>> df.first() Row(age=2, name=u'Alice') @@ -560,7 +537,7 @@ def first(self): return self.head() def __getitem__(self, item): - """ Return the column by given name + """Returns the column as a :class:`Column`. >>> df.select(df['age']).collect() [Row(age=2), Row(age=5)] @@ -580,7 +557,7 @@ def __getitem__(self, item): raise IndexError("unexpected index: %s" % item) def __getattr__(self, name): - """ Return the column by given name + """Returns the :class:`Column` denoted by ``name``. >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] @@ -591,7 +568,11 @@ def __getattr__(self, name): return Column(jc) def select(self, *cols): - """ Selecting a set of expressions. + """Projects a set of expressions and returns a new :class:`DataFrame`. + + :param cols: list of column names (string) or expressions (:class:`Column`). + If one of the column names is '*', that column is expanded to include all columns + in the current DataFrame. >>> df.select('*').collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -606,9 +587,9 @@ def select(self, *cols): return DataFrame(jdf, self.sql_ctx) def selectExpr(self, *expr): - """ - Selects a set of SQL expressions. This is a variant of - `select` that accepts SQL expressions. + """Projects a set of SQL expressions and returns a new :class:`DataFrame`. + + This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] @@ -618,10 +599,12 @@ def selectExpr(self, *expr): return DataFrame(jdf, self.sql_ctx) def filter(self, condition): - """ Filtering rows using the given condition, which could be - :class:`Column` expression or string of SQL expression. + """Filters rows using the given condition. + + :func:`where` is an alias for :func:`filter`. - where() is an alias for filter(). + :param condition: a :class:`Column` of :class:`types.BooleanType` + or a string of SQL expression. 
>>> df.filter(df.age > 3).collect() [Row(age=5, name=u'Bob')] @@ -644,10 +627,13 @@ def filter(self, condition): where = filter def groupBy(self, *cols): - """ Group the :class:`DataFrame` using the specified columns, + """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. + :param cols: list of columns to group by. + Each element should be a column name (string) or an expression (:class:`Column`). + >>> df.groupBy().avg().collect() [Row(AVG(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() @@ -662,7 +648,7 @@ def groupBy(self, *cols): def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups - (shorthand for df.groupBy.agg()). + (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() [Row(MAX(age)=5)] @@ -699,7 +685,7 @@ def subtract(self, other): def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. - This is an alias for `na.drop`. + This is an alias for ``na.drop()``. :param how: 'any' or 'all'. If 'any', drop a row if it contains any nulls. @@ -735,7 +721,7 @@ def dropna(self, how='any', thresh=None, subset=None): return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) def fillna(self, value, subset=None): - """Replace null values, alias for `na.fill`. + """Replace null values, alias for ``na.fill()``. :param value: int, long, float, string, or dict. Value to replace null values with. @@ -790,7 +776,10 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) def withColumn(self, colName, col): - """ Return a new :class:`DataFrame` by adding a column. + """Returns a new :class:`DataFrame` by adding a column. + + :param colName: string, name of the new column. + :param col: a :class:`Column` expression for the new column. >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] @@ -798,7 +787,10 @@ def withColumn(self, colName, col): return self.select('*', col.alias(colName)) def withColumnRenamed(self, existing, new): - """ Rename an existing column to a new name + """REturns a new :class:`DataFrame` by renaming an existing column. + + :param existing: string, name of the existing column to rename. + :param col: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] @@ -809,8 +801,9 @@ def withColumnRenamed(self, existing, new): return self.select(*cols) def toPandas(self): - """ - Collect all the rows and return a `pandas.DataFrame`. + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + + This is only available if Pandas is installed and available. >>> df.toPandas() # doctest: +SKIP age name @@ -823,8 +816,7 @@ def toPandas(self): # Having SchemaRDD for backward compatibility (for docs) class SchemaRDD(DataFrame): - """ - SchemaRDD is deprecated, please use DataFrame + """SchemaRDD is deprecated, please use :class:`DataFrame`. """ @@ -851,10 +843,9 @@ def _api(self, *args): class GroupedData(object): - """ A set of methods for aggregations on a :class:`DataFrame`, - created by DataFrame.groupBy(). + created by :func:`DataFrame.groupBy`. 
""" def __init__(self, jdf, sql_ctx): @@ -862,14 +853,17 @@ def __init__(self, jdf, sql_ctx): self.sql_ctx = sql_ctx def agg(self, *exprs): - """ Compute aggregates by specifying a map from column name - to aggregate methods. + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. - The available aggregate methods are `avg`, `max`, `min`, - `sum`, `count`. + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. - :param exprs: list or aggregate columns or a map from column - name to aggregate methods. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() @@ -894,7 +888,7 @@ def agg(self, *exprs): @dfapi def count(self): - """ Count the number of rows for each group. + """Counts the number of records for each group. >>> df.groupBy(df.age).count().collect() [Row(age=2, count=1), Row(age=5, count=1)] @@ -902,8 +896,11 @@ def count(self): @df_varargs_api def mean(self, *cols): - """Compute the average value for each numeric columns - for each group. This is an alias for `avg`. + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().mean('age').collect() [Row(AVG(age)=3.5)] @@ -913,8 +910,11 @@ def mean(self, *cols): @df_varargs_api def avg(self, *cols): - """Compute the average value for each numeric columns - for each group. + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().avg('age').collect() [Row(AVG(age)=3.5)] @@ -924,8 +924,7 @@ def avg(self, *cols): @df_varargs_api def max(self, *cols): - """Compute the max value for each numeric columns for - each group. + """Computes the max value for each numeric columns for each group. >>> df.groupBy().max('age').collect() [Row(MAX(age)=5)] @@ -935,8 +934,9 @@ def max(self, *cols): @df_varargs_api def min(self, *cols): - """Compute the min value for each numeric column for - each group. + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().min('age').collect() [Row(MIN(age)=2)] @@ -946,8 +946,9 @@ def min(self, *cols): @df_varargs_api def sum(self, *cols): - """Compute the sum for each numeric columns for each - group. + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().sum('age').collect() [Row(SUM(age)=7)] diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5873f09ae3275..8a478fddf0e95 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -76,7 +76,7 @@ def _(col): def countDistinct(col, *cols): - """ Return a new Column for distinct count of `col` or `cols` + """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. 
>>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() [Row(c=2)] @@ -91,7 +91,7 @@ def countDistinct(col, *cols): def approxCountDistinct(col, rsd=None): - """ Return a new Column for approximate distinct count of `col` + """Returns a new :class:`Column` for approximate distinct count of ``col``. >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() [Row(c=2)] @@ -142,7 +142,7 @@ def __call__(self, *cols): def udf(f, returnType=StringType()): - """Create a user defined function (UDF) + """Creates a :class:`Column` expression representing a user defined function (UDF). >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0169028ccc4eb..45eb8b945dcb0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -33,8 +33,7 @@ class DataType(object): - - """Spark SQL DataType""" + """Base class for data types.""" def __repr__(self): return self.__class__.__name__ @@ -67,7 +66,6 @@ def json(self): # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle class PrimitiveTypeSingleton(type): - """Metaclass for PrimitiveType""" _instances = {} @@ -79,66 +77,45 @@ def __call__(cls): class PrimitiveType(DataType): - """Spark SQL PrimitiveType""" __metaclass__ = PrimitiveTypeSingleton class NullType(PrimitiveType): + """Null type. - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. + The data type representing None, used for the types that cannot be inferred. """ class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. + """String data type. """ class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. + """Binary (byte array) data type. """ class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. + """Boolean data type. """ class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. + """Date (datetime.date) data type. """ class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. + """Timestamp (datetime.datetime) data type. """ class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. + """Decimal (decimal.Decimal) data type. """ def __init__(self, precision=None, scale=None): @@ -166,80 +143,55 @@ def __repr__(self): class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. + """Double data type, representing double precision floats. """ class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. + """Float data type, representing single precision floats. """ class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. + """Byte data type, i.e. a signed integer in a single byte. """ def simpleString(self): return 'tinyint' class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. + """Int data type, i.e. a signed 32-bit integer. """ def simpleString(self): return 'int' class LongType(PrimitiveType): + """Long data type, i.e. a signed 64-bit integer. 
- """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. + If the values are beyond the range of [-9223372036854775808, 9223372036854775807], + please use :class:`DecimalType`. """ def simpleString(self): return 'bigint' class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. + """Short data type, i.e. a signed 16-bit integer. """ def simpleString(self): return 'smallint' class ArrayType(DataType): + """Array data type. - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. - The field of containsNull is used to specify if the array has None values. - + :param elementType: :class:`DataType` of each element in the array. + :param containsNull: boolean, whether the array can contain null (None) values. """ def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - + """ >>> ArrayType(StringType()) == ArrayType(StringType(), True) True >>> ArrayType(StringType(), False) == ArrayType(StringType()) @@ -268,29 +220,17 @@ def fromJson(cls, json): class MapType(DataType): + """Map data type. - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. + :param keyType: :class:`DataType` of the keys in the map. + :param valueType: :class:`DataType` of the values in the map. + :param valueContainsNull: indicates whether values can contain null (None) values. + Keys in a map data type are not allowed to be null (None). """ def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - + """ >>> (MapType(StringType(), IntegerType()) ... == MapType(StringType(), IntegerType(), True)) True @@ -325,30 +265,16 @@ def fromJson(cls, json): class StructField(DataType): + """A field in :class:`StructType`. - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - + :param name: string, name of the field. + :param dataType: :class:`DataType` of the field. + :param nullable: boolean, whether the field can be null (None) or not. + :param metadata: a dict from string to simple type that can be serialized to JSON automatically """ def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. 
- :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - + """ >>> (StructField("f1", StringType(), True) ... == StructField("f1", StringType(), True)) True @@ -384,17 +310,13 @@ def fromJson(cls, json): class StructType(DataType): + """Struct type, consisting of a list of :class:`StructField`. - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - + This is the data type representing a :class:`Row`. """ def __init__(self, fields): - """Creates a StructType - + """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 @@ -425,9 +347,9 @@ def fromJson(cls, json): class UserDefinedType(DataType): - """ + """User-defined type (UDT). + .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). """ @classmethod diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 3a3dc70f7285c..bf3c3fe876873 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,14 +21,17 @@ import java.{lang => jl} import scala.collection.JavaConversions._ +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** + * :: Experimental :: * Functionality for working with missing data in [[DataFrame]]s. */ +@Experimental final class DataFrameNaFunctions private[sql](df: DataFrame) { /** From ff1915e12edc4d23e0b4e88933429c2d3470f3d9 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Wed, 1 Apr 2015 11:09:00 +0100 Subject: [PATCH 101/171] [SPARK-4655][Core] Split Stage into ShuffleMapStage and ResultStage subclasses Hi all - this patch changes the Stage class to an abstract class and introduces two new classes that extend it: ShuffleMapStage and ResultStage - with the goal of increasing readability of the DAGScheduler class. Their usage is updated within DAGScheduler. 
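In rough outline (omitting the output-location bookkeeping, pending-task set, and attempt-id state that stay on the abstract class), the hierarchy this patch introduces looks like the sketch below; the complete definitions are in the new ResultStage.scala and ShuffleMapStage.scala files added further down.

package org.apache.spark.scheduler

import org.apache.spark.ShuffleDependency
import org.apache.spark.rdd.RDD
import org.apache.spark.util.CallSite

// Simplified sketch only: common fields remain on the abstract Stage,
// shuffle-specific state moves to ShuffleMapStage, and the link back to the
// submitting job moves to ResultStage.
private[spark] abstract class Stage(
    val id: Int,
    val rdd: RDD[_],
    val numTasks: Int,
    val parents: List[Stage],
    val jobId: Int,
    val callSite: CallSite)

// Intermediate stage that produces map output for a shuffle dependency.
private[spark] class ShuffleMapStage(
    id: Int,
    rdd: RDD[_],
    numTasks: Int,
    parents: List[Stage],
    jobId: Int,
    callSite: CallSite,
    val shuffleDep: ShuffleDependency[_, _, _])
  extends Stage(id, rdd, numTasks, parents, jobId, callSite)

// Final stage of a job; holds the ActiveJob whose action it computes.
private[spark] class ResultStage(
    id: Int,
    rdd: RDD[_],
    numTasks: Int,
    parents: List[Stage],
    jobId: Int,
    callSite: CallSite)
  extends Stage(id, rdd, numTasks, parents, jobId, callSite) {
  var resultOfJob: Option[ActiveJob] = None
}
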
Author: Ilya Ganelin Author: Ilya Ganelin Closes #4708 from ilganeli/SPARK-4655 and squashes the following commits: c248924 [Ilya Ganelin] Merge branch 'SPARK-4655' of github.com:ilganeli/spark into SPARK-4655 d930385 [Ilya Ganelin] Fixed merge conflict from a9a765f [Ilya Ganelin] Update DAGScheduler.scala c03563c [Ilya Ganelin] Minor fixeS c39e971 [Ilya Ganelin] Added return typing for public methods 845bc87 [Ilya Ganelin] Merge branch 'SPARK-4655' of github.com:ilganeli/spark into SPARK-4655 e8031d8 [Ilya Ganelin] Minor string fixes 4ec53ac [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-4655 c004f62 [Ilya Ganelin] Update DAGScheduler.scala a2cb03f [Ilya Ganelin] [SPARK-4655] Replaced usages of Nil and eliminated some code reuse 3d5cf20 [Ilya Ganelin] [SPARK-4655] Moved mima exclude to 1.4 6912c55 [Ilya Ganelin] Resolved merge conflict 4bff208 [Ilya Ganelin] Minor stylistic fixes c6fffbb [Ilya Ganelin] newline 41402ad [Ilya Ganelin] Style fixes 02c6981 [Ilya Ganelin] Merge branch 'SPARK-4655' of github.com:ilganeli/spark into SPARK-4655 c755a09 [Ilya Ganelin] Some more stylistic updates and minor refactoring b6257a0 [Ilya Ganelin] Update MimaExcludes.scala 0f0c624 [Ilya Ganelin] Fixed merge conflict 2eba262 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-4655 6b43d7b [Ilya Ganelin] Got rid of some spaces 6f1a5db [Ilya Ganelin] Revert "More minor formatting and refactoring" 1b3471b [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-4655 c9288e2 [Ilya Ganelin] More minor formatting and refactoring d548caf [Ilya Ganelin] Formatting fix c3ae5c2 [Ilya Ganelin] Explicit typing 0dacaf3 [Ilya Ganelin] Got rid of stale import 6da3a71 [Ilya Ganelin] Trailing whitespace b85c5fe [Ilya Ganelin] Added minor fixes a57dfcd [Ilya Ganelin] Added MiMA exclusion to get around binary compatibility check 83ed849 [Ilya Ganelin] moved braces for consistency 96dd161 [Ilya Ganelin] Fixed minor style error cfd6f10 [Ilya Ganelin] Updated DAGScheduler to use new ResultStage and ShuffleMapStage classes 83494e9 [Ilya Ganelin] Added new Stage classes --- .../apache/spark/scheduler/ActiveJob.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 297 ++++++++++-------- .../apache/spark/scheduler/ResultStage.scala | 40 +++ .../spark/scheduler/ShuffleMapStage.scala | 84 +++++ .../org/apache/spark/scheduler/Stage.scala | 65 +--- project/MimaExcludes.scala | 6 +- 6 files changed, 298 insertions(+), 196 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index b755d8fb15757..50a69379412d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.CallSite */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: Stage, + val finalStage: ResultStage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: CallSite, diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b405bd3338e7c..d35b4f9dbaf88 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -83,7 +83,7 @@ class DAGScheduler( private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new HashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] // Stages we need to run whose parents aren't done @@ -150,7 +150,7 @@ class DAGScheduler( result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + taskMetrics: TaskMetrics): Unit = { eventProcessLoop.post( CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) } @@ -173,18 +173,18 @@ class DAGScheduler( } // Called by TaskScheduler when an executor fails. - def executorLost(execId: String) { + def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - def executorAdded(execId: String, host: String) { + def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String) { + def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { eventProcessLoop.post(TaskSetFailed(taskSet, reason)) } @@ -210,40 +210,65 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { + private def getShuffleMapStage( + shuffleDep: ShuffleDependency[_, _, _], + jobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies registerShuffleDependencies(shuffleDep, jobId) // Then register current shuffleDep - val stage = - newOrUsedStage( - shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, - shuffleDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - + stage } } /** - * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation - * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage - * directly. + * Helper function to eliminate some code re-use when creating new stages. */ - private def newStage( + private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, jobId) + val id = nextStageId.getAndIncrement() + (parentStages, id) + } + + /** + * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in + * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * Production of shuffle map stages should always use newOrUsedShuffleStage, not + * newShuffleMapStage directly. 
+ */ + private def newShuffleMapStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_, _, _]], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, - callSite: CallSite) - : Stage = - { - val parentStages = getParentStages(rdd, jobId) - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite) + callSite: CallSite): ShuffleMapStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, + jobId, callSite, shuffleDep) + + stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) + stage + } + + /** + * Create a ResultStage -- either directly for use as a result stage, or as part of the + * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated + * with the provided jobId. + */ + private def newResultStage( + rdd: RDD[_], + numTasks: Int, + jobId: Int, + callSite: CallSite): ResultStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) + stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -255,20 +280,17 @@ class DAGScheduler( * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ - private def newOrUsedStage( - rdd: RDD[_], - numTasks: Int, + private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, - callSite: CallSite) - : Stage = - { - val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + jobId: Int): ShuffleMapStage = { + val rdd = shuffleDep.rdd + val numTasks = rdd.partitions.size + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) } else { @@ -306,26 +328,23 @@ class DAGScheduler( } } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents.toList } - // Find ancestor missing shuffle dependencies and register into shuffleToMapStage - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (!parentsWithNoMapStage.isEmpty) { + while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = - newOrUsedStage( - currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, - currentShufDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(currentShufDep, jobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } - // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + /** Find ancestor shuffle dependencies that are not 
registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] @@ -351,7 +370,7 @@ class DAGScheduler( } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents @@ -382,7 +401,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } missing.toList @@ -392,7 +411,7 @@ class DAGScheduler( * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. */ - private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { def updateJobIdStageIdMapsList(stages: List[Stage]) { if (stages.nonEmpty) { val s = stages.head @@ -412,7 +431,7 @@ class DAGScheduler( * * @param job The job whose state to cleanup. */ - private def cleanupStateForJobAndIndependentStages(job: ActiveJob) { + private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { val registeredStages = jobIdToStageIds.get(job.jobId) if (registeredStages.isEmpty || registeredStages.get.isEmpty) { logError("No stages registered for job " + job.jobId) @@ -474,8 +493,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = - { + properties: Properties = null): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -504,15 +522,13 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - { + properties: Properties = null): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { - case JobSucceeded => { + case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - } case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) @@ -526,9 +542,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { + properties: Properties = null): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -541,12 +555,12 @@ class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob(jobId: Int) { + def cancelJob(jobId: Int): Unit = { logInfo("Asked to cancel job " + jobId) eventProcessLoop.post(JobCancelled(jobId)) } - def cancelJobGroup(groupId: String) { + def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) } @@ -554,7 +568,7 @@ class DAGScheduler( /** * Cancel all jobs that are running or waiting in the queue. 
*/ - def cancelAllJobs() { + def cancelAllJobs(): Unit = { eventProcessLoop.post(AllJobsCancelled) } @@ -722,13 +736,12 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) - { - var finalStage: Stage = null + properties: Properties = null) { + var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite) + finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -773,7 +786,7 @@ class DAGScheduler( if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) - if (missing == Nil) { + if (missing.isEmpty) { logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage, jobId.get) } else { @@ -794,13 +807,15 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() + // First figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = { - if (stage.isShuffleMap) { - (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil) - } else { - val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + stage match { + case stage: ShuffleMapStage => + (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + case stage: ResultStage => + val job = stage.resultOfJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) } } @@ -830,18 +845,21 @@ class DAGScheduler( try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = - if (stage.isShuffleMap) { - closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() - } else { - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() - } + val taskBinaryBytes: Array[Byte] = stage match { + case stage: ShuffleMapStage => + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + case stage: ResultStage => + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. 
case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) runningStages -= stage + + // Abort execution return case NonFatal(e) => abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") @@ -849,20 +867,22 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = if (stage.isShuffleMap) { - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } - } else { - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + val tasks: Seq[Task[_]] = stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } + + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } } if (tasks.size > 0) { @@ -877,8 +897,17 @@ class DAGScheduler( // SparkListenerStageCompleted here in case there are no tasks to run. outputCommitCoordinator.stageEnd(stage.id) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - logDebug("Stage " + stage + " is actually done; %b %d %d".format( - stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) + + val debugString = stage match { + case stage: ShuffleMapStage => + s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})" + case stage : ResultStage => + s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + } + logDebug(debugString) runningStages -= stage } } @@ -968,7 +997,10 @@ class DAGScheduler( stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => - stage.resultOfJob match { + // Cast to ResultStage here because it's part of the ResultTask + // TODO Refactor this out to a function that accepts a ResultStage + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { updateAccumulators(event) @@ -976,7 +1008,7 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - markStageAsFinished(stage) + markStageAsFinished(resultStage) cleanupStateForJobAndIndependentStages(job) listenerBus.post( SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) @@ -988,7 +1020,7 @@ class DAGScheduler( job.listener.taskSucceeded(rt.outputId, event.result) } catch { case e: Exception => - // TODO: Perhaps we want to mark the stage as failed? + // TODO: Perhaps we want to mark the resultStage as failed? 
job.listener.jobFailed(new SparkDriverExecutionException(e)) } } @@ -997,6 +1029,7 @@ class DAGScheduler( } case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId @@ -1004,50 +1037,54 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partitionId, status) + shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) { - markStageAsFinished(stage) + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - if (stage.shuffleDep.isDefined) { - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, - changeEpoch = true) - } + + // We supply true to increment the epoch number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // epoch incremented to refetch them. + // TODO: Only increment the epoch number if this is not the first time + // we registered these map outputs. 
+ mapOutputTracker.registerMapOutputs( + shuffleStage.shuffleDep.shuffleId, + shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + changeEpoch = true) + clearCacheLocs() - if (stage.outputLocs.exists(_ == Nil)) { - // Some tasks had failed; let's resubmit this stage + if (shuffleStage.outputLocs.contains(Nil)) { + // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " (" + stage.name + + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + - stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) - submitStage(stage) + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) + .map(_._2).mkString(", ")) + submitStage(shuffleStage) } else { val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waitingStages) { - logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) + for (shuffleStage <- waitingStages) { + logInfo("Missing parents for " + shuffleStage + ": " + + getMissingParentStages(shuffleStage)) } - for (stage <- waitingStages if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage + for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) + { + newlyRunnable += shuffleStage } waitingStages --= newlyRunnable runningStages ++= newlyRunnable for { - stage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(stage) + shuffleStage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(shuffleStage) } { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage, jobId) + logInfo("Submitting " + shuffleStage + " (" + + shuffleStage.rdd + "), which is now runnable") + submitMissingTasks(shuffleStage, jobId) } } } @@ -1204,9 +1241,7 @@ class DAGScheduler( } } - /** - * Fails a job and all stages that are only used by that job, and cleans up relevant state. - */ + /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { val error = new SparkException(failureReason) var ableToCancelStages = true @@ -1254,9 +1289,7 @@ class DAGScheduler( } } - /** - * Return true if one of stage's ancestors is target. - */ + /** Return true if one of stage's ancestors is target. */ private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true @@ -1282,7 +1315,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } visitedRdds.contains(target.rdd) @@ -1312,9 +1345,7 @@ class DAGScheduler( private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]) - : Seq[TaskLocation] = - { + visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. 
SPARK-695 if (!visited.add((rdd,partition))) { @@ -1323,12 +1354,12 @@ class DAGScheduler( } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) - if (!cached.isEmpty) { + if (cached.nonEmpty) { return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList - if (!rddPrefs.isEmpty) { + if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } // If the RDD has narrow dependencies, pick the first partition of the first narrow dep @@ -1412,7 +1443,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.sc.stop() } - override def onStop() { + override def onStop(): Unit = { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala new file mode 100644 index 0000000000000..c0f3d5a13d623 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.CallSite + +/** + * The ResultStage represents the final stage in a job. + */ +private[spark] class ResultStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + + // The active job for this result stage. Will be empty if the job has already finished + // (e.g., because the job was cancelled). + var resultOfJob: Option[ActiveJob] = None + + override def toString: String = "ResultStage " + id +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala new file mode 100644 index 0000000000000..d02210743484c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.ShuffleDependency +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CallSite + +/** + * The ShuffleMapStage represents the intermediate stages in a job. + */ +private[spark] class ShuffleMapStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite, + val shuffleDep: ShuffleDependency[_, _, _]) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + + override def toString: String = "ShuffleMapStage " + id + + var numAvailableOutputs: Long = 0 + + def isAvailable: Boolean = numAvailableOutputs == numPartitions + + val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + + def addOutputLoc(partition: Int, status: MapStatus): Unit = { + val prevList = outputLocs(partition) + outputLocs(partition) = status :: prevList + if (prevList == Nil) { + numAvailableOutputs += 1 + } + } + + def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location == bmAddress) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + var becameUnavailable = false + for (partition <- 0 until numPartitions) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location.executorId == execId) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + numAvailableOutputs -= 1 + } + } + if (becameUnavailable) { + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 4cbc6e84a6bdd..5d0ddb8377c33 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,6 @@ import scala.collection.mutable.HashSet import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -47,29 +46,23 @@ import org.apache.spark.util.CallSite * be updated for each attempt. 
* */ -private[spark] class Stage( +private[spark] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, val callSite: CallSite) extends Logging { - val isShuffleMap = shuffleDep.isDefined val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - var numAvailableOutputs = 0 /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - /** For stages that are the final (consists of only ResultTasks), link to the ActiveJob. */ - var resultOfJob: Option[ActiveJob] = None var pendingTasks = new HashSet[Task[_]] - private var nextAttemptId = 0 + private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm @@ -77,53 +70,6 @@ private[spark] class Stage( /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ var latestInfo: StageInfo = StageInfo.fromStage(this) - def isAvailable: Boolean = { - if (!isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, status: MapStatus) { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) - } - } - /** Return a new attempt id, starting with 0. 
*/ def newAttemptId(): Int = { val id = nextAttemptId @@ -133,11 +79,8 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - override def toString: String = "Stage " + id - - override def hashCode(): Int = id - - override def equals(other: Any): Boolean = other match { + override final def hashCode(): Int = id + override final def equals(other: Any): Boolean = other match { case stage: Stage => stage != null && stage.id == id case _ => false } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index efd59a7e5470f..54500f7c2701f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -54,7 +54,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor") ) ++ Seq( - // SPARK-6510 Add a Graph#minus method acting as Set#difference + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") ) From 412262346f6f48e641bd6899c703efa31aeaba1e Mon Sep 17 00:00:00 2001 From: Florian Verhein Date: Wed, 1 Apr 2015 11:10:43 +0100 Subject: [PATCH 102/171] [EC2] [SPARK-6600] Open ports in ec2/spark_ec2.py to allow HDFS NFS gateway Authorizes incoming access to master on the ports required to use the hadoop hdfs nfs gateway from outside the cluster. Author: Florian Verhein Closes #5257 from florianverhein/master and squashes the following commits: 72a586a [Florian Verhein] [EC2] [SPARK-6600] initial impl --- ec2/spark_ec2.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index c467cd08ed742..5507a9c5a4733 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -456,6 +456,13 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # HDFS NFS gateway requires 111,2049,4242 for tcp & udp + master_group.authorize('tcp', 111, 111, authorized_address) + master_group.authorize('udp', 111, 111, authorized_address) + master_group.authorize('tcp', 2049, 2049, authorized_address) + master_group.authorize('udp', 2049, 2049, authorized_address) + master_group.authorize('tcp', 4242, 4242, authorized_address) + master_group.authorize('udp', 4242, 4242, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created From d824c11c9fe8af1ca1d7c694b2fb81289eb83f97 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 1 Apr 2015 11:11:56 +0100 Subject: [PATCH 103/171] [SPARK-6597][Minor] Replace `input:checkbox` with `input[type="checkbox"]` in additional-metrics.js In additional-metrics.js, there are some selector notation like `input:checkbox` but JQuery's official document says `input[type="checkbox"]` is better. 
https://api.jquery.com/checkbox-selector/ Author: Kousuke Saruta Closes #5254 from sarutak/SPARK-6597 and squashes the following commits: a253bc4 [Kousuke Saruta] Replaced input:checkbox with input[type="checkbox"] --- .../org/apache/spark/ui/static/additional-metrics.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 14ba37d7c9bd9..013db8df9b363 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -30,7 +30,7 @@ $(function() { stripeSummaryTable(); - $("input:checkbox").click(function() { + $('input[type="checkbox"]').click(function() { var column = "table ." + $(this).attr("name"); $(column).toggle(); stripeSummaryTable(); @@ -39,15 +39,15 @@ $(function() { $("#select-all-metrics").click(function() { if (this.checked) { // Toggle all un-checked options. - $('input:checkbox:not(:checked)').trigger('click'); + $('input[type="checkbox"]:not(:checked)').trigger('click'); } else { // Toggle all checked options. - $('input:checkbox:checked').trigger('click'); + $('input[type="checkbox"]:checked').trigger('click'); } }); // Trigger a click on the checkbox if a user clicks the label next to it. $("span.additional-metric-title").click(function() { - $(this).parent().find('input:checkbox').trigger('click'); + $(this).parent().find('input[type="checkbox"]').trigger('click'); }); }); From 0358b08db85b3ee4ae70834626e7a42311bcc635 Mon Sep 17 00:00:00 2001 From: jayson Date: Wed, 1 Apr 2015 11:12:55 +0100 Subject: [PATCH 104/171] SPARK-6626 [DOCS]: Corrected Scala:TwitterUtils parameters Per Sean Owen's request, here is the update call for TwitterUtils using Scala :) Author: jayson Closes #5295 from JaysonSunshine/master and squashes the following commits: df1d056 [jayson] Corrected Scala:TwitterUtils parameters --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 6d6229625f3f9..262512a639046 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -704,7 +704,7 @@ create a DStream using data from Twitter's stream of tweets, you have to do the {% highlight scala %} import org.apache.spark.streaming.twitter._ -TwitterUtils.createStream(ssc) +TwitterUtils.createStream(ssc, None) {% endhighlight %}
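For context, a slightly fuller usage sketch of the corrected call above. This is illustrative only: the application name, batch interval, and "#spark" filter are arbitrary, and passing None assumes OAuth credentials are supplied through the usual twitter4j.oauth.* system properties rather than an explicit Authorization object.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.twitter._

// Sketch only: credentials come from twitter4j.oauth.* system properties,
// hence the None in place of a twitter4j Authorization.
val conf = new SparkConf().setAppName("TwitterExample")
val ssc = new StreamingContext(conf, Seconds(10))
val allTweets = TwitterUtils.createStream(ssc, None)
val sparkTweets = TwitterUtils.createStream(ssc, None, Seq("#spark"))
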
From d36c5fca7b9227c4c6e1b0c1455269b5fd8d4852 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 1 Apr 2015 21:34:45 +0800 Subject: [PATCH 105/171] [SPARK-6608] [SQL] Makes DataFrame.rdd a lazy val Before 1.3.0, `SchemaRDD.id` works as a unique identifier of each `SchemaRDD`. In 1.3.0, unlike `SchemaRDD`, `DataFrame` is no longer an RDD, and `DataFrame.rdd` is actually a function which always returns a new RDD instance. Making `DataFrame.rdd` a lazy val should bring the unique identifier back. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5265) Author: Cheng Lian Closes #5265 from liancheng/spark-6608 and squashes the following commits: 7500968 [Cheng Lian] Updates javadoc 7f37d21 [Cheng Lian] Makes DataFrame.rdd a lazy val --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5cd0a18ff688c..19cfa15f27b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -952,10 +952,12 @@ class DataFrame private[sql]( ///////////////////////////////////////////////////////////////////////////// /** - * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. + * Represents the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. Note that the RDD is + * memoized. Once called, it won't change even if you change any query planning related Spark SQL + * configurations (e.g. `spark.sql.shuffle.partitions`). * @group rdd */ - def rdd: RDD[Row] = { + lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) From ee11be258251adf900680927ba200bf46512cc04 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 1 Apr 2015 16:26:54 +0100 Subject: [PATCH 106/171] SPARK-6433 hive tests to import spark-sql test JAR for QueryTest access 1. Test JARs are built & published 1. log4j.resources is explicitly excluded. Without this, downstream test run logging depends on the order the JARs are listed/loaded 1. sql/hive pulls in spark-sql &...spark-catalyst for its test runs 1. The copied in test classes were rm'd, and a test edited to remove its now duplicate assert method 1. Spark streaming is now build with the same plugin/phase as the rest, but its shade plugin declaration is kept in (so different from the rest of the test plugins). Due to (#2), this means the test JAR no longer includes its log4j file. Outstanding issues: * should the JARs be shaded? `spark-streaming-test.jar` does, but given these are test jars for developers only, especially in the same spark source tree, it's hard to justify. * `maven-jar-plugin` v 2.6 was explicitly selected; without this the apache-1.4 parent template JAR version (2.4) chosen. * Are there any other resources to exclude? 
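Concretely, once the test JARs are published, a sql/hive test suite can import the shared helpers instead of carrying its own copies. The suite below is a hypothetical sketch (the suite name, query, and TestHive wildcard import are illustrative); QueryTest and PlanTest now resolve against the spark-sql and spark-catalyst test JARs rather than the duplicated sources this patch deletes.

package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHive._   // assumed to provide sql(), table(), ...

// Hypothetical suite reusing the shared helper instead of a local duplicate.
class ExampleHiveQuerySuite extends QueryTest {
  test("checkAnswer comes from the shared QueryTest helper") {
    checkAnswer(sql("SELECT 1 + 1"), Row(2))
  }
}
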
Author: Steve Loughran Closes #5119 from steveloughran/stevel/patches/SPARK-6433-test-jars and squashes the following commits: 81ceb01 [Steve Loughran] SPARK-6433 add a clearer comment explaining what the plugin is doing & why a6dca33 [Steve Loughran] SPARK-6433 : pull configuration section form archive plugin c2b5f89 [Steve Loughran] SPARK-6433 omit "jar" goal from jar plugin fdac51b [Steve Loughran] SPARK-6433 -002; indentation & delegate plugin version to parent 650f442 [Steve Loughran] SPARK-6433 patch 001: test JARs are built; sql/hive pulls in spark-sql & spark-catalyst for its test runs --- pom.xml | 20 +++ sql/hive/pom.xml | 14 ++ .../org/apache/spark/sql/QueryTest.scala | 140 ------------------ .../spark/sql/catalyst/plans/PlanTest.scala | 57 ------- .../spark/sql/hive/CachedTableSuite.scala | 15 -- streaming/pom.xml | 28 ---- 6 files changed, 34 insertions(+), 240 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala diff --git a/pom.xml b/pom.xml index 3eb3da2cd8af3..42bd926a2fcb8 100644 --- a/pom.xml +++ b/pom.xml @@ -1265,6 +1265,7 @@ create-source-jar jar-no-fork + test-jar-no-fork @@ -1473,6 +1474,25 @@ org.scalatest scalatest-maven-plugin + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + prepare-package + + test-jar + + + + log4j.properties + + + + + diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a9816f6c38cd2..04440076a26a3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -89,6 +89,20 @@ junit test + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + ${project.version} + test + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala deleted file mode 100644 index 0270e63557963..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConversions._ - -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util._ - - -/** - * *** DUPLICATED FROM sql/core. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. 
- */ -class QueryTest extends PlanTest { - - /** - * Runs the plan and makes sure the answer contains all of the keywords, or the - * none of keywords are listed in the answer - * @param rdd the [[DataFrame]] to be executed - * @param exists true for make sure the keywords are listed in the output, otherwise - * to make sure none of the keyword are not listed in the output - * @param keywords keyword in string array - */ - def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString - for (key <- keywords) { - if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") - } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") - } - } - } - - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(rdd, expectedAnswer) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } -} - -object QueryTest { - /** - * Runs the plan and makes sure the answer matches the expected result. - * If there was exception during the execution or the contents of the DataFrame does not - * match the expected result, an error message will be returned. Otherwise, a [[None]] will - * be returned. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). 
- val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case o => o - }) - } - if (!isSorted) converted.sortBy(_.toString) else converted - } - val sparkAnswer = try rdd.collect().toSeq catch { - case e: Exception => - val errorMessage = - s""" - |Exception thrown while executing query: - |${rdd.queryExecution} - |== Exception == - |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - |Results do not match for query: - |${rdd.logicalPlan} - |== Analyzed Plan == - |${rdd.queryExecution.analyzed} - |== Physical Plan == - |${rdd.queryExecution.executedPlan} - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin - return Some(errorMessage) - } - - return None - } - - def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(rdd, expectedAnswer.toSeq) match { - case Some(errorMessage) => errorMessage - case None => null - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala deleted file mode 100644 index 98f1c0e69e29d..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.scalatest.FunSuite - -/** - * *** DUPLICATED FROM sql/catalyst/plans. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class PlanTest extends FunSuite { - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. 
- */ - protected def normalizeExprIds(plan: LogicalPlan) = { - plan transformAllExpressions { - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - case a: Alias => - Alias(a.child, a.name)(exprId = ExprId(0)) - } - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeExprIds(plan1) - val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 221a0c263d36c..c188264072a84 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -24,21 +24,6 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { - /** - * Throws a test failed exception when the number of cached tables differs from the expected - * number. - */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan diff --git a/streaming/pom.xml b/streaming/pom.xml index 23a8358d45c2a..5ca55a4f680bb 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -97,34 +97,6 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-jar-plugin - - - - test-jar - - - - test-jar-on-test-compile - test-compile - - test-jar - - - - - org.apache.maven.plugins maven-shade-plugin From 2275acce7ba5fac83c58554d7ee9f4c7f3e866cf Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 1 Apr 2015 13:29:04 -0700 Subject: [PATCH 107/171] [SPARK-6651][MLLIB] delegate dense vector arithmetics to the underlying numpy array Users should be able to use numpy operators directly on dense vectors. davies atalwalkar Author: Xiangrui Meng Closes #5312 from mengxr/SPARK-6651 and squashes the following commits: e665c5c [Xiangrui Meng] wrap the result in a dense vector 23dfca3 [Xiangrui Meng] delegate dense vector arithmetics to the underlying numpy array --- python/pyspark/mllib/linalg.py | 38 +++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f5aad28afda0f..8b791ff6a7877 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -173,7 +173,24 @@ def toArray(self): class DenseVector(Vector): """ - A dense vector represented by a value array. + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. 
+ + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) """ def __init__(self, ar): if isinstance(ar, basestring): @@ -292,6 +309,25 @@ def __ne__(self, other): def __getattr__(self, item): return getattr(self.array, item) + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __mod__ = _delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rmod__ = _delegate("__rmod__") + class SparseVector(Vector): """ From fb25e8c7f45b4f96561e3f7434a0f4dfce8ddefe Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 1 Apr 2015 15:15:47 -0700 Subject: [PATCH 108/171] [SPARK-6657] [Python] [Docs] fixed python doc build warnings fixed python doc build warnings CC whomever wants to review: rxin mengxr davies Author: Joseph K. Bradley Closes #5317 from jkbradley/python-doc-warnings and squashes the following commits: 4cd43c2 [Joseph K. Bradley] fixed python doc build warnings --- python/docs/pyspark.streaming.rst | 2 +- python/pyspark/mllib/tree.py | 26 ++++++++++---------------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index 7890d9dcaac21..50822c93faba1 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -10,7 +10,7 @@ Module contents :show-inheritance: pyspark.streaming.kafka module ----------------------------- +------------------------------ .. automodule:: pyspark.streaming.kafka :members: :undoc-members: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index bf288d76447bd..a7a4d2aaf855b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -286,21 +286,18 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain - calculation. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + :param impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy". :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting features - (default: 100) + (default: 100) :param seed: Random seed for bootstrapping and choosing feature subsets. 
:return: RandomForestModel that can be used for prediction @@ -365,13 +362,10 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for - regression. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. :param impurity: Criterion used for information gain calculation. Supported values: "variance". From f084c5de14eb10a6aba82a39e03e7877926ebb9e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 1 Apr 2015 16:06:11 -0700 Subject: [PATCH 109/171] [SPARK-6578] [core] Fix thread-safety issue in outbound path of network library. While the inbound path of a netty pipeline is thread-safe, the outbound path is not. That means that multiple threads can compete to write messages to the next stage of the pipeline. The network library sometimes breaks a single RPC message into multiple buffers internally to avoid copying data (see MessageEncoder). This can result in the following scenario (where "FxBy" means "frame x, buffer y"): T1 F1B1 F1B2 \ \ \ \ socket F1B1 F2B1 F1B2 F2B2 / / / / T2 F2B1 F2B2 And the frames now cannot be rebuilt on the receiving side because the different messages have been mixed up on the wire. The fix wraps these multi-buffer messages into a `FileRegion` object so that these messages are written "atomically" to the next pipeline handler. Author: Marcelo Vanzin Closes #5234 from vanzin/SPARK-6578 and squashes the following commits: 16b2d70 [Marcelo Vanzin] Forgot to update a type. c9c2e4e [Marcelo Vanzin] Review comments: simplify some code. 9c888ac [Marcelo Vanzin] Small style nits. 8474bab [Marcelo Vanzin] Fix multiple calls to MessageWithHeader.transferTo(). e26509f [Marcelo Vanzin] Merge branch 'master' into SPARK-6578 c503f6c [Marcelo Vanzin] Implement a custom FileRegion instead of using locks. 84aa7ce [Marcelo Vanzin] Rename handler to the correct name. 432f3bd [Marcelo Vanzin] Remove unneeded method. 8d70e60 [Marcelo Vanzin] Fix thread-safety issue in outbound path of network library. 
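As a rough illustration of the hazard described above (this is not code from the patch), the following minimal Python sketch models two threads each writing a multi-buffer frame to a shared outbound channel; the channel list, function names, and frame labels are hypothetical:

    import threading

    def write_unsafely(channel, header, body):
        # Header and body go out as two separate writes, so another thread's
        # buffers can land between them (e.g. F1B1, F2B1, F1B2, F2B2).
        channel.append(header)
        channel.append(body)

    def write_as_one_unit(channel, header, body):
        # Analogue of the fix: header and body are handed to the next stage as
        # a single object, so a frame can no longer be split apart.
        channel.append((header, body))

    channel = []
    t1 = threading.Thread(target=write_unsafely, args=(channel, "F1B1", "F1B2"))
    t2 = threading.Thread(target=write_unsafely, args=(channel, "F2B1", "F2B2"))
    t1.start(); t2.start(); t1.join(); t2.join()
    # channel can now end up as F1B1, F2B1, F1B2, F2B2, and neither frame can
    # be rebuilt on the receiving side; with write_as_one_unit each frame
    # stays contiguous.

The patch itself achieves this grouping without locks by wrapping the header and body into a single FileRegion (the MessageWithHeader class added below), so the pair is written atomically to the next pipeline handler.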
--- network/common/pom.xml | 5 + .../network/protocol/MessageEncoder.java | 6 +- .../network/protocol/MessageWithHeader.java | 106 ++++++++++++++ .../network/ByteArrayWritableChannel.java | 55 ++++++++ .../apache/spark/network/ProtocolSuite.java | 46 +++++-- .../protocol/MessageWithHeaderSuite.java | 129 ++++++++++++++++++ .../src/test/resources/log4j.properties | 27 ++++ 7 files changed, 364 insertions(+), 10 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java create mode 100644 network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java create mode 100644 network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java create mode 100644 network/common/src/test/resources/log4j.properties diff --git a/network/common/pom.xml b/network/common/pom.xml index 7b51845206f4a..22c738bde6d42 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -80,6 +80,11 @@ mockito-all test + + org.slf4j + slf4j-log4j12 + test + diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 91d1e8a538a77..0f999f5dfe8d8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { in.encode(header); assert header.writableBytes() == 0; - out.add(header); if (body != null && bodyLength > 0) { - out.add(body); + out.add(new MessageWithHeader(header, body, bodyLength)); + } else { + out.add(header); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 0000000000000..215a8517e8608 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Preconditions; +import com.google.common.primitives.Ints; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +/** + * A wrapper message that holds two separate pieces (a header and a body) to avoid + * copying the body's content. 
+ */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + long written = 0; + + if (position < headerLength) { + written += copyByteBuf(header, target); + if (header.readableBytes() > 0) { + totalBytesTransferred += written; + return written; + } + } + + if (body instanceof FileRegion) { + // Adjust the position. If the write is happening as part of the same call where the header + // (or some part of it) is written, `position` will be less than the header size, so we want + // to start from position 0 in the FileRegion object. Otherwise, we start from the position + // requested by the caller. + long bodyPos = position > headerLength ? position - headerLength : 0; + written += ((FileRegion)body).transferTo(target, bodyPos); + } else if (body instanceof ByteBuf) { + written += copyByteBuf((ByteBuf) body, target); + } + + totalBytesTransferred += written; + return written; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int written = target.write(buf.nioBuffer()); + buf.skipBytes(written); + return written; + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java new file mode 100644 index 0000000000000..b525ed69fc9fb --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +public class ByteArrayWritableChannel implements WritableByteChannel { + + private final byte[] data; + private int offset; + + public ByteArrayWritableChannel(int size) { + this.data = new byte[size]; + this.offset = 0; + } + + public byte[] getData() { + return data; + } + + @Override + public int write(ByteBuffer src) { + int available = src.remaining(); + src.get(data, offset, available); + offset += available; + return available; + } + + @Override + public void close() { + + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 43dc0cf8c7194..860dd6d9b3915 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -17,26 +17,34 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.primitives.Ints; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.MessageToMessageEncoder; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( @@ -51,7 +59,8 @@ private void testServerToClient(Message msg) { } private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( @@ -83,4 +92,25 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } + + /** + * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer + * bytes, but messages, so this is needed so that the frame decoder on the receiving side can + * understand what MessageWithHeader actually contains. 
+ */ + private static class FileRegionEncoder extends MessageToMessageEncoder { + + @Override + public void encode(ChannelHandlerContext ctx, FileRegion in, List out) + throws Exception { + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); + while (in.transfered() < in.count()) { + in.transferTo(channel, in.transfered()); + } + out.add(Unpooled.wrappedBuffer(channel.getData())); + } + + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 0000000000000..ff985096d72d5 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf body = Unpooled.copyLong(84); + MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, 
msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e8da774f7ca9e --- /dev/null +++ b/network/common/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO From ccafd757eda478913f783f3127be715bf6413740 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 1 Apr 2015 16:47:18 -0700 Subject: [PATCH 110/171] [SPARK-6642][MLLIB] use 1.2 lambda scaling and remove addImplicit from NormalEquation This PR changes lambda scaling from number of users/items to number of explicit ratings. The latter is the behavior in 1.2. Slight refactor of NormalEquation to make it independent of ALS models. 
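To make the new scaling concrete, here is a small NumPy sketch (not code from this patch) of the weighted normal equation that NormalEquation.add now accumulates, with lambda multiplied by the number of explicit ratings before solving; the function and variable names are hypothetical, and A and b reuse the reference data from the CholeskySolver test further below:

    import numpy as np

    def solve_weighted(A, b, c, num_explicits, reg_param):
        # Normal equation for  min  sum_i c_i * (a_i^T x - b_i)^2 + lam * x^T x,
        # where lam = num_explicits * reg_param (the 1.2 behaviour, per ALS-WR).
        ata = (A.T * c) @ A           # sum_i c_i * a_i a_i^T
        atb = A.T @ (c * b)           # sum_i c_i * b_i * a_i
        lam = num_explicits * reg_param
        return np.linalg.solve(ata + lam * np.eye(A.shape[1]), atb)

    A = np.array([[1.0, 2.0], [1.0, 3.0], [1.0, 4.0]])
    b = np.array([4.0, 9.0, 16.0])
    x = solve_weighted(A, b, c=np.ones(3), num_explicits=3, reg_param=0.5)
    # x is roughly [-0.1155556, 3.28], matching the CholeskySolver test below.

Here num_explicits is passed in directly for clarity; in the patch it is counted while the ratings are being accumulated and the scaled lambda is handed to the solver.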
srowen codexiang Author: Xiangrui Meng Closes #5314 from mengxr/SPARK-6642 and squashes the following commits: dc655a1 [Xiangrui Meng] relax python tests f410df2 [Xiangrui Meng] use 1.2 scaling and remove addImplicit from NormalEquation --- .../apache/spark/ml/recommendation/ALS.scala | 67 +++++++++-------- .../spark/ml/recommendation/ALSSuite.scala | 71 +++++++------------ python/pyspark/mllib/recommendation.py | 6 +- 3 files changed, 60 insertions(+), 84 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 514b4ef98dc5b..52c9e95d6012f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -320,7 +320,7 @@ object ALS extends Logging { /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { - /** Solves a least squares problem (possibly with other constraints). */ + /** Solves a least squares problem with regularization (possibly with other constraints). */ def solve(ne: NormalEquation, lambda: Double): Array[Float] } @@ -332,20 +332,19 @@ object ALS extends Logging { /** * Solves a least squares problem with L2 regularization: * - * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * min norm(A x - b)^2^ + lambda * norm(x)^2^ * * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) - * @param lambda regularization constant, which will be scaled by n + * @param lambda regularization constant * @return the solution x */ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val k = ne.k // Add scaled lambda to the diagonals of AtA. - val scaledlambda = lambda * ne.n var i = 0 var j = 2 while (i < ne.triK) { - ne.ata(i) += scaledlambda + ne.ata(i) += lambda i += j j += 1 } @@ -391,7 +390,7 @@ object ALS extends Logging { override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val rank = ne.k initialize(rank) - fillAtA(ne.ata, lambda * ne.n) + fillAtA(ne.ata, lambda) val x = NNLS.solve(ata, ne.atb, workspace) ne.reset() x.map(x => x.toFloat) @@ -420,7 +419,15 @@ object ALS extends Logging { } } - /** Representing a normal equation (ALS' subproblem). */ + /** + * Representing a normal equation to solve the following weighted least squares problem: + * + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. + * + * Its normal equation is given by + * + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { /** Number of entries in the upper triangular part of a k-by-k matrix. */ @@ -429,8 +436,6 @@ object ALS extends Logging { val ata = new Array[Double](triK) /** A^T^ * b */ val atb = new Array[Double](k) - /** Number of observations. */ - var n = 0 private val da = new Array[Double](k) private val upper = "U" @@ -444,28 +449,13 @@ object ALS extends Logging { } /** Adds an observation. */ - def add(a: Array[Float], b: Float): this.type = { - require(a.length == k) - copyToDouble(a) - blas.dspr(upper, k, 1.0, da, 1, ata) - blas.daxpy(k, b.toDouble, da, 1, atb, 1) - n += 1 - this - } - - /** - * Adds an observation with implicit feedback. Note that this does not increment the counter. 
- */ - def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { + require(c >= 0.0) require(a.length == k) - // Extension to the original paper to handle b < 0. confidence is a function of |b| instead - // so that it is never negative. - val confidence = 1.0 + alpha * math.abs(b) copyToDouble(a) - blas.dspr(upper, k, confidence - 1.0, da, 1, ata) - // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. - if (b > 0) { - blas.daxpy(k, confidence, da, 1, atb, 1) + blas.dspr(upper, k, c, da, 1, ata) + if (b != 0.0) { + blas.daxpy(k, c * b, da, 1, atb, 1) } this } @@ -475,7 +465,6 @@ object ALS extends Logging { require(other.k == k) blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) - n += other.n this } @@ -483,7 +472,6 @@ object ALS extends Logging { def reset(): Unit = { ju.Arrays.fill(ata, 0.0) ju.Arrays.fill(atb, 0.0) - n = 0 } } @@ -1114,6 +1102,7 @@ object ALS extends Logging { ls.merge(YtY.get) } var i = srcPtrs(j) + var numExplicits = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -1121,13 +1110,23 @@ object ALS extends Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - ls.addImplicit(srcFactor, rating, alpha) + // Extension to the original paper to handle b < 0. confidence is a function of |b| + // instead so that it is never negative. c1 is confidence - 1.0. + val c1 = alpha * math.abs(rating) + // For rating <= 0, the corresponding preference is 0. So the term below is only added + // for rating > 0. Because YtY is already added, we need to adjust the scaling here. + if (rating > 0) { + numExplicits += 1 + ls.add(srcFactor, (c1 + 1.0) / c1, c1) + } } else { ls.add(srcFactor, rating) + numExplicits += 1 } i += 1 } - dstFactors(j) = solver.solve(ls, regParam) + // Weight lambda by the number of explicit ratings based on the ALS-WR paper. 
+ dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 } dstFactors @@ -1141,7 +1140,7 @@ object ALS extends Logging { private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { factorBlocks.values.aggregate(new NormalEquation(rank))( seqOp = (ne, factors) => { - factors.foreach(ne.add(_, 0.0f)) + factors.foreach(ne.add(_, 0.0)) ne }, combOp = (ne1, ne2) => ne1.merge(ne2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 0bb06e9e8ac9c..29d4ec5f85c1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -68,39 +68,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } } - test("normal equation construction with explict feedback") { + test("normal equation construction") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 3.0f) - .add(Array(4.0f, 5.0f), 6.0f) + .add(Array(1.0f, 2.0f), 3.0) + .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 2) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5") // b = np.matrix("3; 6") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8) val ne1 = new NormalEquation(2) - .add(Array(7.0f, 8.0f), 9.0f) + .add(Array(7.0f, 8.0f), 9.0) ne0.merge(ne1) - assert(ne0.n === 3) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5; 7 8") // b = np.matrix("3; 6; 9") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2, 1])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8) intercept[IllegalArgumentException] { - ne0.add(Array(1.0f), 2.0f) + ne0.add(Array(1.0f), 2.0) } intercept[IllegalArgumentException] { - ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f), 0.0, -1.0) } intercept[IllegalArgumentException] { val ne2 = new NormalEquation(3) @@ -108,41 +111,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } ne0.reset() - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) } - test("normal equation construction with implicit feedback") { - val k = 2 - val alpha = 0.5 - val ne0 = new NormalEquation(k) - .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) - .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) - .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) - assert(ne0.k === k) - assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 0) // addImplicit doesn't increase the count. 
- // NumPy code that computes the expected values: - // alpha = 0.5 - // A = np.matrix("-5 -4; -2 -1; 1 2") - // b = np.matrix("-3; 0; 3") - // b1 = b > 0 - // c = 1.0 + alpha * np.abs(b) - // C = np.diag(c.A1) - // I = np.eye(3) - // ata = A.transpose() * (C - I) * A - // atb = A.transpose() * C * b1 - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) - } - test("CholeskySolver") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 4.0f) - .add(Array(1.0f, 3.0f), 9.0f) - .add(Array(1.0f, 4.0f), 16.0f) + .add(Array(1.0f, 2.0f), 4.0) + .add(Array(1.0f, 3.0f), 9.0) + .add(Array(1.0f, 4.0f), 16.0) val ne1 = new NormalEquation(k) .merge(ne0) @@ -154,13 +132,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { // x0 = np.linalg.lstsq(A, b)[0] assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) - val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + val x1 = chol.solve(ne1, 1.5).map(_.toDouble) // NumPy code that computes the expected solution, where lambda is scaled by n: - // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b) assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) } diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index b094e50856f70..c5c4c13dae105 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) >>> model.predict(2, 2) - 0.43... + 0.4... >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 2, seed=0) @@ -82,14 +82,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 0.43... + 0.4... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = MatrixFactorizationModel.load(sc, path) >>> sameModel.predict(2,2) - 0.43... + 0.4... >>> sameModel.predictAll(testset).collect() [Rating(... 
>>> try: From 2fa3b47dbf38aae58514473932c69bbd35de4e4c Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 1 Apr 2015 17:03:39 -0700 Subject: [PATCH 111/171] [SPARK-6576] [MLlib] [PySpark] DenseMatrix in PySpark should support indexing Support indexing in DenseMatrices in PySpark Author: MechCoder Closes #5232 from MechCoder/SPARK-6576 and squashes the following commits: a735078 [MechCoder] Change bounds a062025 [MechCoder] Matrices are stored in column order 7917bc1 [MechCoder] [SPARK-6576] DenseMatrix in PySpark should support indexing --- python/pyspark/mllib/linalg.py | 10 ++++++++++ python/pyspark/mllib/tests.py | 7 +++++++ 2 files changed, 17 insertions(+) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 8b791ff6a7877..51c1490b1618d 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -670,6 +670,16 @@ def toArray(self): """ return self.values.reshape((self.numRows, self.numCols), order='F') + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + return self.values[i + j * self.numRows] + def __eq__(self, other): return (isinstance(other, DenseMatrix) and self.numRows == other.numRows and diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3bb0f0ca68128..893fc6f491ab3 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -135,6 +135,13 @@ def test_sparse_vector_indexing(self): for ind in [4, -5, 7.8]: self.assertRaises(ValueError, sv.__getitem__, ind) + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEquals(mat[i, j], expected[i][j]) + class ListTests(PySparkTestCase): From 86b43993517104e6d5ad0785704ceec6db8acc20 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 1 Apr 2015 17:19:36 -0700 Subject: [PATCH 112/171] [SPARK-6580] [MLLIB] Optimize LogisticRegressionModel.predictPoint https://issues.apache.org/jira/browse/SPARK-6580 Author: Yanbo Liang Closes #5249 from yanboliang/spark-6580 and squashes the following commits: 6f47f21 [Yanbo Liang] address comments 4e0bd0f [Yanbo Liang] fix typos 04e2e2a [Yanbo Liang] trigger jenkins cad5bcd [Yanbo Liang] Optimize LogisticRegressionModel.predictPoint --- .../classification/LogisticRegression.scala | 55 +++++++++---------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index e7c3599ff619c..057b628c6a586 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,6 +62,15 @@ class LogisticRegressionModel ( s" but was given weights of length ${weights.size}") } + private val dataWithBiasSize: Int = weights.size / (numClasses - 1) + + private val weightsArray: Array[Double] = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + /** * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. 
*/ @@ -74,6 +83,7 @@ class LogisticRegressionModel ( * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. + * It is only used for binary classification. */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -84,6 +94,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + * It is only used for binary classification. */ @Experimental def getThreshold: Option[Double] = threshold @@ -91,6 +102,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. + * It is only used for binary classification. */ @Experimental def clearThreshold(): this.type = { @@ -106,7 +118,6 @@ class LogisticRegressionModel ( // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { - require(numFeatures == weightMatrix.size) val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { @@ -114,30 +125,9 @@ class LogisticRegressionModel ( case None => score } } else { - val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - - val weightsArray = weightMatrix match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weightMatrix.getClass}.") - } - - val margins = (0 until numClasses - 1).map { i => - var margin = 0.0 - dataMatrix.foreachActive { (index, value) => - if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) - } - // Intercept is required to be added into margin. - if (dataMatrix.size + 1 == dataWithBiasSize) { - margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) - } - margin - } - /** - * Find the one with maximum margins. If the maxMargin is negative, then the prediction - * result will be the first class. + * Compute and find the one with maximum margins. If the maxMargin is negative, then the + * prediction result will be the first class. * * PS, if you want to compute the probabilities for each outcome instead of the outcome * with maximum probability, remember to subtract the maxMargin from margins if maxMargin @@ -145,13 +135,20 @@ class LogisticRegressionModel ( */ var bestClass = 0 var maxMargin = 0.0 - var i = 0 - while(i < margins.size) { - if (margins(i) > maxMargin) { - maxMargin = margins(i) + val withBias = dataMatrix.size + 1 == dataWithBiasSize + (0 until numClasses - 1).foreach { i => + var margin = 0.0 + dataMatrix.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) + } + // Intercept is required to be added into margin. + if (withBias) { + margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) + } + if (margin > maxMargin) { + maxMargin = margin bestClass = i + 1 } - i += 1 } bestClass.toDouble } From 757b2e91756ba49d7d1ab89abf19b00c7f5fd721 Mon Sep 17 00:00:00 2001 From: ksonj Date: Wed, 1 Apr 2015 17:23:57 -0700 Subject: [PATCH 113/171] [SPARK-6553] [pyspark] Support functools.partial as UDF Use `f.__repr__()` instead of `f.__name__` when instantiating `UserDefinedFunction`s, so `functools.partial`s may be used. 
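For context (this example is not part of the patch), a functools.partial object carries no __name__, which is what the one-line fallback in functions.py below guards against; the helper name display_name is hypothetical:

    import functools

    def add(col, param):
        return None if col is None else col + param

    plus_four = functools.partial(add, param=4)

    def display_name(f):
        # Same fallback the patched functions.py uses when naming a Python UDF.
        return f.__name__ if hasattr(f, '__name__') else f.__class__.__name__

    display_name(add)        # 'add'
    display_name(plus_four)  # 'partial'; reading plus_four.__name__ directly
                             # raises AttributeError, the failure this removes.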
Author: ksonj Closes #5206 from ksonj/partials and squashes the following commits: ea66f3d [ksonj] Inserted blank lines for PEP8 compliance d81b02b [ksonj] added tests for udf with partial function and callable object 2c76100 [ksonj] Makes UDFs work with all types of callables b814a12 [ksonj] support functools.partial as udf (cherry picked from commit 98f72dfc17853b570d05c20e97c78919682b6df6) Signed-off-by: Josh Rosen --- python/pyspark/sql/functions.py | 3 ++- python/pyspark/sql/tests.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8a478fddf0e95..146ba6f3e0d98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -123,7 +123,8 @@ def _create_judf(self): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 258464b7f230d..b3a6a2c6a9229 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -25,6 +25,7 @@ import shutil import tempfile import pickle +import functools import py4j @@ -41,6 +42,7 @@ from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction class ExamplePointUDT(UserDefinedType): @@ -114,6 +116,35 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() From 4815bc2128c7f6d4d21da730b8c72da087233b34 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 1 Apr 2015 18:17:07 -0700 Subject: [PATCH 114/171] [SPARK-6660][MLLIB] pythonToJava doesn't recognize object arrays davies Author: Xiangrui Meng Closes #5318 from mengxr/SPARK-6660 and squashes the following commits: 0f66ec2 [Xiangrui Meng] recognize object arrays ad8c42f [Xiangrui Meng] add a test for SPARK-6660 --- 
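As a loose Python-side analogue of the Scala fix below (this sketch is not Spark code and the names are invented), a consumer of a batched stream should accept both list and array batches rather than assuming a single concrete type:

    import numpy as np

    def flatten_batches(batches):
        for batch in batches:
            # A deserialized batch may be a plain list or an ndarray (the
            # object-array case this commit adds); handle both uniformly.
            if isinstance(batch, (list, np.ndarray)):
                for row in batch:
                    yield row
            else:
                yield batch  # unbatched object

    list(flatten_batches([[1, 2], np.array([3.0, 4.0]), 5]))
    # -> [1, 2, 3.0, 4.0, 5]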
.../apache/spark/mllib/api/python/PythonMLLibAPI.scala | 5 ++++- python/pyspark/mllib/tests.py | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 662ec5fbed453..5995d6df97c15 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1113,7 +1113,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 893fc6f491ab3..6e9c68ec8a5c1 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -36,6 +36,7 @@ else: import unittest +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint @@ -641,6 +642,13 @@ def test_idf_model(self): idf = model.idf() self.assertEqual(len(idf), 11) + +class SerDeTest(PySparkTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" From 899ebcb1448126f40be784ce42e69218e9a1ead7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Apr 2015 18:36:06 -0700 Subject: [PATCH 115/171] [SPARK-6578] Small rewrite to make the logic more clear in MessageWithHeader.transferTo. Author: Reynold Xin Closes #5319 from rxin/SPARK-6578 and squashes the following commits: 7c62a64 [Reynold Xin] Small rewrite to make the logic more clear in transferTo. --- .../network/protocol/MessageWithHeader.java | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index 215a8517e8608..d686a951467cf 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -21,15 +21,15 @@ import java.nio.channels.WritableByteChannel; import com.google.common.base.Preconditions; -import com.google.common.primitives.Ints; import io.netty.buffer.ByteBuf; import io.netty.channel.FileRegion; import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; /** - * A wrapper message that holds two separate pieces (a header and a body) to avoid - * copying the body's content. + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. 
*/ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { @@ -63,32 +63,36 @@ public long transfered() { return totalBytesTransferred; } + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { + public long transferTo(final WritableByteChannel target, final long position) throws IOException { Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); - long written = 0; - - if (position < headerLength) { - written += copyByteBuf(header, target); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; if (header.readableBytes() > 0) { - totalBytesTransferred += written; - return written; + return writtenHeader; } } + // Bytes written for body in this call. + long writtenBody = 0; if (body instanceof FileRegion) { - // Adjust the position. If the write is happening as part of the same call where the header - // (or some part of it) is written, `position` will be less than the header size, so we want - // to start from position 0 in the FileRegion object. Otherwise, we start from the position - // requested by the caller. - long bodyPos = position > headerLength ? position - headerLength : 0; - written += ((FileRegion)body).transferTo(target, bodyPos); + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); } else if (body instanceof ByteBuf) { - written += copyByteBuf((ByteBuf) body, target); + writtenBody = copyByteBuf((ByteBuf) body, target); } + totalBytesTransferred += writtenBody; - totalBytesTransferred += written; - return written; + return writtenHeader + writtenBody; } @Override @@ -102,5 +106,4 @@ private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOExcept buf.skipBytes(written); return written; } - } From 191524e7401fcdfae46dc7e6a64c28907b1b1c20 Mon Sep 17 00:00:00 2001 From: Chet Mancini Date: Wed, 1 Apr 2015 21:39:46 -0700 Subject: [PATCH 116/171] [SPARK-6658][SQL] Update DataFrame documentation to fix type references. First contribution here; would love to be getting some code contributions in soon. Let me know if there's anything about contribution process I should improve. Author: Chet Mancini Closes #5316 from chetmancini/SPARK_6658_dataframe_doc and squashes the following commits: 53b627a [Chet Mancini] [SQL] SPARK-6658: Update DataFrame documentation to refer to correct types --- .../main/scala/org/apache/spark/sql/DataFrame.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 19cfa15f27b09..ce0890906bf1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -273,7 +273,7 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) /** - * Prints the plans (logical and physical) to the console for debugging purpose. 
+ * Prints the plans (logical and physical) to the console for debugging purposes. * @group basic */ def explain(extended: Boolean): Unit = { @@ -285,7 +285,7 @@ class DataFrame private[sql]( } /** - * Only prints the physical plan to the console for debugging purpose. + * Only prints the physical plan to the console for debugging purposes. * @group basic */ def explain(): Unit = explain(extended = false) @@ -976,8 +976,8 @@ class DataFrame private[sql]( def javaRDD: JavaRDD[Row] = toJavaRDD /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. * * @group basic */ @@ -1252,7 +1252,7 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. * If you pass `true` for `allowExisting`, it will drop any table with the * given name; if you pass `false`, it will throw if the table already @@ -1276,7 +1276,7 @@ class DataFrame private[sql]( } /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * Assumes the table already exists and has a compatible schema. If you * pass `true` for `overwrite`, it will `TRUNCATE` the table before * performing the `INSERT`s. From 2bc7fe7f7eb31b8f0591611b1e66b601bba8a4b7 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 2 Apr 2015 12:56:34 +0800 Subject: [PATCH 117/171] Revert "[SPARK-6618][SQL] HiveMetastoreCatalog.lookupRelation should use fine-grained lock" This reverts commit 314afd0e2f08dd8d3333d3143712c2c79fa40d1e. 
--- .../apache/spark/sql/hive/HiveMetastoreCatalog.scala | 12 +++--------- .../spark/sql/hive/execution/SQLQuerySuite.scala | 11 ----------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2b5d031741a63..f0076cef13777 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -173,16 +173,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = { + alias: Option[String]): LogicalPlan = synchronized { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - val table = try { - synchronized { - client.getTable(databaseName, tblName) - } - } catch { + val table = try client.getTable(databaseName, tblName) catch { case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException => throw new NoSuchTableException } @@ -204,9 +200,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else { val partitions: Seq[Partition] = if (table.isPartitioned) { - synchronized { - HiveShim.getAllPartitionsOf(client, table).toSeq - } + HiveShim.getAllPartitionsOf(client, table).toSeq } else { Nil } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2065f0d60d92f..310c2bfdf1011 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -457,15 +457,4 @@ class SQLQuerySuite extends QueryTest { dropTempTable("data") setConf("spark.sql.hive.convertCTAS", originalConf) } - - test("sanity test for SPARK-6618") { - (1 to 100).par.map { i => - val tableName = s"SPARK_6618_table_$i" - sql(s"CREATE TABLE $tableName (col1 string)") - catalog.lookupRelation(Seq(tableName)) - table(tableName) - tables() - sql(s"DROP TABLE $tableName") - } - } } From 40df5d49bb5c80cd3a1e2d7c853c0b5ea901adf3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Apr 2015 23:11:38 -0700 Subject: [PATCH 118/171] [SPARK-6663] [SQL] use Literal.create instread of constructor In order to do inbound checking and type conversion, we should use Literal.create() instead of constructor. 
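The change below routes Literal construction through a companion-object factory, Literal.create, so that bounds checking and type conversion can later be added in one place without touching call sites (in this patch create still just forwards to the constructor). A minimal sketch of that factory pattern, using made-up Lit and DType types rather than the Catalyst ones:

sealed trait DType
case object IntType extends DType
case object StrType extends DType

// A protected constructor nudges callers toward the factory, mirroring
// `case class Literal protected (...)` in this patch.
case class Lit protected (value: Any, dataType: DType)

object Lit {
  // Single place where validation/coercion can be added later.
  def create(value: Any, dataType: DType): Lit = {
    val coerced = (value, dataType) match {
      case (s: String, IntType) => s.toInt // illustrative coercion only
      case (v, _)               => v
    }
    new Lit(coerced, dataType)
  }
}

Call sites then read Lit.create("42", IntType); once real checks exist in create, every caller picks them up automatically.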
Author: Davies Liu Closes #5320 from davies/literal and squashes the following commits: 1667604 [Davies Liu] fix style and add comment 5f8c0fd [Davies Liu] use Literal.create instread of constructor --- .../apache/spark/sql/catalyst/SqlParser.scala | 8 +- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 3 +- .../sql/catalyst/expressions/literals.scala | 7 +- .../sql/catalyst/optimizer/Optimizer.scala | 42 ++-- .../analysis/HiveTypeCoercionSuite.scala | 4 +- .../ExpressionEvaluationSuite.scala | 204 +++++++++--------- .../optimizer/ConstantFoldingSuite.scala | 70 +++--- .../sql/catalyst/trees/TreeNodeSuite.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 8 +- .../apache/spark/sql/parquet/newParquet.scala | 21 +- .../ParquetPartitionDiscoverySuite.scala | 20 +- .../spark/sql/hive/HiveInspectors.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 20 +- .../spark/sql/hive/HiveInspectorSuite.scala | 16 +- 16 files changed, 220 insertions(+), 213 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b176f7e729a42..89f4a19add1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -316,13 +316,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal(s, StringType) } - | NULL ^^^ Literal(null, NullType) + | stringLit ^^ {case s => Literal.create(s, StringType) } + | NULL ^^^ Literal.create(null, NullType) ) protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal(true, BooleanType) - | FALSE ^^^ Literal(false, BooleanType) + ( TRUE ^^^ Literal.create(true, BooleanType) + | FALSE ^^^ Literal.create(false, BooleanType) ) protected lazy val numericLiteral: Parser[Literal] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c578d084a45b6..119cb9c3a4400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -140,10 +140,10 @@ class Analyzer( case x: Expression if nonSelectedGroupExprSet.contains(x) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal(null, expr.dataType) + Literal.create(null, expr.dataType) case x if x == g.gid => // replace the groupingId with concrete value (the bit mask) - Literal(bitmask, IntegerType) + Literal.create(bitmask, IntegerType) }) result += GroupExpression(substitution) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 3c7b46e0702a2..9a33eb145273e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -115,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. 
*/ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN", StringType) + val stringNaN = Literal.create("NaN", StringType) def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 30da4faa3f1c6..406de38d1c483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -505,7 +505,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private var count: Long = _ private val sum = MutableLiteral(zero.eval(null), calcType) - private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) + private def addFunction(value: Any) = Add(sum, + Cast(Literal.create(value, expr.dataType), calcType)) override def eval(input: Row): Any = { if (count == 0L) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 19f3fc9c2291a..0e2d593e94124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -41,6 +41,8 @@ object Literal { case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + + def create(v: Any, dataType: DataType): Literal = Literal(v, dataType) } /** @@ -62,7 +64,10 @@ object IntegerLiteral { } } -case class Literal(value: Any, dataType: DataType) extends LeafExpression { +/** + * In order to do type checking, use Literal.create() instead of constructor + */ +case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = true override def nullable: Boolean = value == null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c23d3b61887c6..93e69d409cb91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -218,12 +218,12 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) - case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) - case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) - case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType) + case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ StructGetField(Literal(null, _), _, _) 
=> Literal.create(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) @@ -235,36 +235,36 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => true } if (newChildren.length == 0) { - Literal(null, e.dataType) + Literal.create(null, e.dataType) } else if (newChildren.length == 1) { newChildren(0) } else { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } } @@ -284,13 +284,13 @@ object ConstantFolding extends Rule[LogicalPlan] { case l: Literal => l // Fold expressions that are foldable. - case e if e.foldable => Literal(e.eval(null), e.dataType) + case e if e.foldable => Literal.create(e.eval(null), e.dataType) // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. 
case In(Literal(v, _), list) if list.exists { case Literal(candidate, _) if candidate == v => true case _ => false - } => Literal(true, BooleanType) + } => Literal.create(true, BooleanType) } } } @@ -647,7 +647,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => Cast( - Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index ecbb54218d457..70aef1cac421a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -127,11 +127,11 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest( Coalesce(Literal(1.0) :: Literal(1) - :: Literal(1.0, FloatType) + :: Literal.create(1.0, FloatType) :: Nil), Coalesce(Cast(Literal(1.0), DoubleType) :: Cast(Literal(1), DoubleType) - :: Cast(Literal(1.0, FloatType), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) ruleTest( Coalesce(Literal(1L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 1183a0d899dda..3dbefa40d2808 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -111,7 +111,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - checkEvaluation(!Literal(v, BooleanType), answer) + checkEvaluation(!Literal.create(v, BooleanType), answer) } } @@ -155,7 +155,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test(s"3VL $name") { truthTable.foreach { case (l,r,answer) => - val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) checkEvaluation(expr, answer) } } @@ -175,12 +175,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Divide(Literal(1), Literal(0)), null) checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) + 
checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("Remainder") { @@ -190,12 +190,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Remainder(Literal(1), Literal(0)), null) checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("INSET") { @@ -222,14 +222,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(MaxOf(1L, 2L), 2L) checkEvaluation(MaxOf(2L, 1L), 2L) - checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2) + checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) } test("LIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType).like("a"), null) - checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) - checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -264,13 +264,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) + checkEvaluation(Literal.create(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) } test("RLIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal(null, StringType), null) - checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike "abdef", true) 
checkEvaluation("abbbbc" rlike "a.*c", true) @@ -381,7 +381,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) - checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) + checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) } test("date") { @@ -507,8 +507,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("array casting") { - val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + val array = Literal.create(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) { val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) @@ -556,10 +556,10 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("map casting") { - val map = Literal( + val map = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal( + val map_notNull = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) @@ -617,14 +617,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("struct casting") { - val struct = Literal( + val struct = Literal.create( Row("123", "abc", "", null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal( + val struct_notNull = Literal.create( Row("123", "abc", ""), StructType(Seq( StructField("a", StringType, nullable = false), @@ -712,7 +712,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("complex casting") { - val complex = Literal( + val complex = Literal.create( Row( Seq("123", "abc", ""), Map("a" -> "123", "b" -> "abc", "c" -> ""), @@ -755,30 +755,30 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c2.isNull, true, row) checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(Literal(1, ShortType).isNull, false) - checkEvaluation(Literal(1, ShortType).isNotNull, true) + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - checkEvaluation(Literal(null, ShortType).isNull, true) - checkEvaluation(Literal(null, ShortType).isNotNull, false) + checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row) + checkEvaluation(If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) checkEvaluation(If(c3, c1, c2), "^Ba*n", row) 
checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), - Literal("a", StringType), Literal("b", StringType)), "b", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) checkEvaluation(c1 in (c1, c2), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) } test("case when") { @@ -793,9 +793,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) @@ -841,17 +841,17 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation(GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(3, typeMap, true), - Literal(null, StringType)), null, row) + Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), - Literal(null, IntegerType)), null, row) + Literal.create(null, IntegerType)), null, row) def quickBuildGetField(expr: Expression, fieldName: String) = { expr.dataType match { @@ -864,7 +864,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { def quickResolve(u: UnresolvedGetField) 
= quickBuildGetField(u.child, u.fieldName) checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row) + checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row) val typeS_notNullable = StructType( StructField("a", StringType, nullable = false) @@ -874,8 +874,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) - assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true) - assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) @@ -890,13 +890,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val c4 = 'a.int.at(3) checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100) + checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) - checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(-c1, -1, row) checkEvaluation(c1 + c2, 3, row) @@ -914,12 +914,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val c4 = 'a.double.at(3) checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0) + checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row) - checkEvaluation(Add(Literal(null, DoubleType), c2), null, row) - checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) checkEvaluation(-c1, -1.1, row) checkEvaluation(c1 + c2, 3.1, row) @@ -940,9 +940,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(LessThan(c1, c4), null, row) checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) + 
checkEvaluation(LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 < c2, true, row) checkEvaluation(c1 <= c2, true, row) @@ -954,8 +954,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 <=> c4, false, row) checkEvaluation(c4 <=> c6, true, row) checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row) - checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row) + checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) + checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) } test("StringComparison") { @@ -966,17 +966,17 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 contains "b", true, row) checkEvaluation(c1 contains "x", false, row) checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal(null, StringType), null, row) + checkEvaluation(c1 contains Literal.create(null, StringType), null, row) checkEvaluation(c1 startsWith "a", true, row) checkEvaluation(c1 startsWith "b", false, row) checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal(null, StringType), null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) checkEvaluation(c1 endsWith "c", true, row) checkEvaluation(c1 endsWith "b", false, row) checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal(null, StringType), null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) } test("Substring") { @@ -985,54 +985,54 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val s = 'a.string.at(0) // substring from zero position with less-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) // substring from zero position with full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(7, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), "example", row) // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), "xa", row) // substring from nonzero position with full length - 
checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), "xample", row) // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), "xample", row) // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), "", row) // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), "", row) // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), null, new GenericRow(Array[Any](null))) // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), null, row) // substring(_, _, null) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), null, row) // 2-arg substring from zero position - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "example", row) // 2-arg substring from nonzero position - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "xample", row) val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable === false) + assert(Substring(s_notNull, Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) @@ -1050,7 +1050,7 @@ 
class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null))) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } @@ -1064,22 +1064,22 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(BitwiseAnd(c1, c4), null, row) checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseAnd(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseOr(c1, c4), null, row) checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseOr(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseXor(c1, c4), null, row) checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseXor(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseNot(c4), null, row) checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 & c2, 0, row) checkEvaluation(c1 | c2, 3, row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ef10c0aece716..a0efe9e2e7f6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -182,33 +182,33 @@ class ConstantFoldingSuite extends PlanTest { IsNull(Literal(null)) as 'c1, IsNotNull(Literal(null)) as 'c2, - GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, - GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, + GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem(Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, UnresolvedGetField( - Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), + Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), "a") as 'c5, - UnaryMinus(Literal(null, IntegerType)) as 'c6, + UnaryMinus(Literal.create(null, IntegerType)) as 'c6, Cast(Literal(null), IntegerType) as 'c7, - Not(Literal(null, BooleanType)) as 'c8, + Not(Literal.create(null, BooleanType)) as 'c8, - Add(Literal(null, IntegerType), 1) as 'c9, - Add(1, Literal(null, IntegerType)) as 'c10, + 
Add(Literal.create(null, IntegerType), 1) as 'c9, + Add(1, Literal.create(null, IntegerType)) as 'c10, - EqualTo(Literal(null, IntegerType), 1) as 'c11, - EqualTo(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal.create(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal.create(null, IntegerType)) as 'c12, - Like(Literal(null, StringType), "abc") as 'c13, - Like("abc", Literal(null, StringType)) as 'c14, + Like(Literal.create(null, StringType), "abc") as 'c13, + Like("abc", Literal.create(null, StringType)) as 'c14, - Upper(Literal(null, StringType)) as 'c15, + Upper(Literal.create(null, StringType)) as 'c15, - Substring(Literal(null, StringType), 0, 1) as 'c16, - Substring("abc", Literal(null, IntegerType), 1) as 'c17, - Substring("abc", 0, Literal(null, IntegerType)) as 'c18, + Substring(Literal.create(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, - Contains(Literal(null, StringType), "abc") as 'c19, - Contains("abc", Literal(null, StringType)) as 'c20 + Contains(Literal.create(null, StringType), "abc") as 'c19, + Contains("abc", Literal.create(null, StringType)) as 'c20 ) val optimized = Optimize(originalQuery.analyze) @@ -219,31 +219,31 @@ class ConstantFoldingSuite extends PlanTest { Literal(true) as 'c1, Literal(false) as 'c2, - Literal(null, IntegerType) as 'c3, - Literal(null, IntegerType) as 'c4, - Literal(null, IntegerType) as 'c5, + Literal.create(null, IntegerType) as 'c3, + Literal.create(null, IntegerType) as 'c4, + Literal.create(null, IntegerType) as 'c5, - Literal(null, IntegerType) as 'c6, - Literal(null, IntegerType) as 'c7, - Literal(null, BooleanType) as 'c8, + Literal.create(null, IntegerType) as 'c6, + Literal.create(null, IntegerType) as 'c7, + Literal.create(null, BooleanType) as 'c8, - Literal(null, IntegerType) as 'c9, - Literal(null, IntegerType) as 'c10, + Literal.create(null, IntegerType) as 'c9, + Literal.create(null, IntegerType) as 'c10, - Literal(null, BooleanType) as 'c11, - Literal(null, BooleanType) as 'c12, + Literal.create(null, BooleanType) as 'c11, + Literal.create(null, BooleanType) as 'c12, - Literal(null, BooleanType) as 'c13, - Literal(null, BooleanType) as 'c14, + Literal.create(null, BooleanType) as 'c13, + Literal.create(null, BooleanType) as 'c14, - Literal(null, StringType) as 'c15, + Literal.create(null, StringType) as 'c15, - Literal(null, StringType) as 'c16, - Literal(null, StringType) as 'c17, - Literal(null, StringType) as 'c18, + Literal.create(null, StringType) as 'c16, + Literal.create(null, StringType) as 'c17, + Literal.create(null, StringType) as 'c18, - Literal(null, BooleanType) as 'c19, - Literal(null, BooleanType) as 'c20 + Literal.create(null, BooleanType) as 'c19, + Literal.create(null, BooleanType) as 'c20 ).analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e7ce92a2160b6..274f3ede0045c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -90,7 +90,7 @@ class TreeNodeSuite extends FunSuite { } test("transform works on nodes with Option children") { - val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = 
Dummy(None) val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 89682d25ca7dc..a8018b9213f2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -93,7 +93,7 @@ case class GeneratedAggregate( } val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal(null, calcType) + val initialValue = Literal.create(null, calcType) // Coalasce avoids double calculation... // but really, common sub expression elimination would be better.... @@ -137,13 +137,13 @@ case class GeneratedAggregate( expr.dataType match { case DecimalType.Fixed(_, _) => If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), + Literal.create(null, a.dataType), Cast(Divide( Cast(currentSum, DecimalType.Unlimited), Cast(currentCount, DecimalType.Unlimited)), a.dataType)) case _ => If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), + Literal.create(null, a.dataType), Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) } @@ -156,7 +156,7 @@ case class GeneratedAggregate( case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal(null, expr.dataType) + val initialValue = Literal.create(null, expr.dataType) val updateMax = MaxOf(currentMax, expr) AggregateEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 19800ad88c031..43f260d3ef8d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -872,9 +872,9 @@ private[sql] object ParquetRelation2 extends Logging { * PartitionValues( * Seq("a", "b", "c"), * Seq( - * Literal(42, IntegerType), - * Literal("hello", StringType), - * Literal(3.14, FloatType))) + * Literal.create(42, IntegerType), + * Literal.create("hello", StringType), + * Literal.create(3.14, FloatType))) * }}} */ private[parquet] def parsePartition( @@ -953,15 +953,16 @@ private[sql] object ParquetRelation2 extends Logging { raw: String, defaultPartitionName: String): Literal = { // First tries integral types - Try(Literal(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal(JLong.parseLong(raw), LongType))) + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types - .orElse(Try(Literal(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) + .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + if (raw == defaultPartitionName) Literal.create(null, NullType) + else Literal.create(raw, StringType) } } @@ -980,7 +981,7 @@ private[sql] object ParquetRelation2 extends Logging { } literals.map { 
case l @ Literal(_, dataType) => - Literal(Cast(l, desiredType).eval(), desiredType) + Literal.create(Cast(l, desiredType).eval(), desiredType) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index adb3c9391f6c2..b7561ce7298cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -45,11 +45,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) } - check("10", Literal(10, IntegerType)) - check("1000000000000000", Literal(1000000000000000L, LongType)) - check("1.5", Literal(1.5, FloatType)) - check("hello", Literal("hello", StringType)) - check(defaultPartitionName, Literal(null, NullType)) + check("10", Literal.create(10, IntegerType)) + check("1000000000000000", Literal.create(1000000000000000L, LongType)) + check("1.5", Literal.create(1.5, FloatType)) + check("hello", Literal.create("hello", StringType)) + check(defaultPartitionName, Literal.create(null, NullType)) } test("parse partition") { @@ -75,22 +75,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "file://path/a=10", PartitionValues( ArrayBuffer("a"), - ArrayBuffer(Literal(10, IntegerType)))) + ArrayBuffer(Literal.create(10, IntegerType)))) check( "file://path/a=10/b=hello/c=1.5", PartitionValues( ArrayBuffer("a", "b", "c"), ArrayBuffer( - Literal(10, IntegerType), - Literal("hello", StringType), - Literal(1.5, FloatType)))) + Literal.create(10, IntegerType), + Literal.create("hello", StringType), + Literal.create(1.5, FloatType)))) check( "file://path/a=10/b_hello/c=1.5", PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal(1.5, FloatType)))) + ArrayBuffer(Literal.create(1.5, FloatType)))) checkThrows[AssertionError]("file://path/=10", "Empty partition column name") checkThrows[AssertionError]("file://path/a=", "Empty partition column value") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4afa2e71d77cc..921c6194c7b76 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -593,7 +593,7 @@ private[hive] trait HiveInspectors { case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") // ideally, we don't test the foldable here(but in optimizer), however, some of the // Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly. 
- case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType)) + case _ if expr.foldable => toInspector(Literal.create(expr.eval(), expr.dataType)) // For those non constant expression, map to object inspector according to its data type case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index cd8e7c09eea5b..5be09a11ad641 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1201,7 +1201,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C CreateArray(children.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) + Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType)) case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) @@ -1213,9 +1213,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedFunction(name, UnresolvedStar(None) :: Nil) /* Literals */ - case Token("TOK_NULL", Nil) => Literal(null, NullType) - case Token(TRUE(), Nil) => Literal(true, BooleanType) - case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) @@ -1226,21 +1226,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C try { if (ast.getText.endsWith("L")) { // Literal bigint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) } else if (ast.getText.endsWith("S")) { // Literal smallint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) } else if (ast.getText.endsWith("Y")) { // Literal tinyint. 
- v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") v = Literal(Decimal(strVal)) } else { - v = Literal(ast.getText.toDouble, DoubleType) - v = Literal(ast.getText.toLong, LongType) - v = Literal(ast.getText.toInt, IntegerType) + v = Literal.create(ast.getText.toDouble, DoubleType) + v = Literal.create(ast.getText.toLong, LongType) + v = Literal.create(ast.getText.toInt, IntegerType) } } catch { case nfe: NumberFormatException => // Do nothing diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3181cfe40016c..c482c6de8a736 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -79,9 +79,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1,2,3)) :: - Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal(Row(1,2.0d,3.0f), + Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1,2.0d,3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -166,7 +166,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) - val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType))) + val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) checkValues(constantData, constantData.zip(constantWritableOIs).map { case (d, oi) => unwrap(wrap(d, oi), oi) @@ -212,8 +212,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = row(0) :: row(0) :: Nil checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { @@ -222,7 +222,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = Map(row(0) -> row(1)) checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, 
unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } } From 6562787b963204763a33e1c4e9d192db913af1fc Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 1 Apr 2015 23:42:09 -0700 Subject: [PATCH 119/171] [SPARK-6627] Some clean-up in shuffle code. Before diving into review #4450 I did a look through the existing shuffle code to learn how it works. Unfortunately, there are some very confusing things in this code. This patch makes a few small changes to simplify things. It is not easily to concisely describe the changes because of how convoluted the issues were, but they are fairly small logically: 1. There is a trait named `ShuffleBlockManager` that only deals with one logical function which is retrieving shuffle block data given shuffle block coordinates. This trait has two implementors FileShuffleBlockManager and IndexShuffleBlockManager. Confusingly the vast majority of those implementations have nothing to do with this particular functionality. So I've renamed the trait to ShuffleBlockResolver and documented it. 2. The aforementioned trait had two almost identical methods, for no good reason. I removed one method (getBytes) and modified callers to use the other one. I think the behavior is preserved in all cases. 3. The sort shuffle code uses an identifier "0" in the reduce slot of a BlockID as a placeholder. I made it into a constant since it needs to be consistent across multiple places. I think for (3) there is actually a better solution that would avoid the need to do this type of workaround/hack in the first place, but it's more complex so I'm punting it for now. Author: Patrick Wendell Closes #5286 from pwendell/cleanup and squashes the following commits: c71fbc7 [Patrick Wendell] Open interface back up for testing f36edd5 [Patrick Wendell] Code review feedback d1c0494 [Patrick Wendell] Style fix a406079 [Patrick Wendell] [HOTFIX] Some clean-up in shuffle code. --- .../shuffle/FileShuffleBlockManager.scala | 7 +---- .../shuffle/IndexShuffleBlockManager.scala | 27 +++++++++---------- ...nager.scala => ShuffleBlockResolver.scala} | 14 ++++++---- .../apache/spark/shuffle/ShuffleManager.scala | 5 +++- .../apache/spark/shuffle/ShuffleWriter.scala | 2 +- .../shuffle/hash/HashShuffleManager.scala | 8 +++--- .../shuffle/sort/SortShuffleManager.scala | 9 ++++--- .../shuffle/sort/SortShuffleWriter.scala | 6 ++--- .../apache/spark/storage/BlockManager.scala | 14 ++++------ .../util/collection/ExternalSorter.scala | 6 +++-- .../hash/HashShuffleManagerSuite.scala | 2 +- .../spark/tools/StoragePerfTester.scala | 2 +- 12 files changed, 51 insertions(+), 51 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{ShuffleBlockManager.scala => ShuffleBlockResolver.scala} (68%) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index d0178dfde6935..5be3ed771e534 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -67,7 +67,7 @@ private[spark] trait ShuffleWriterGroup { // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). 
private[spark] class FileShuffleBlockManager(conf: SparkConf) - extends ShuffleBlockManager with Logging { + extends ShuffleBlockResolver with Logging { private val transportConf = SparkTransportConf.fromSparkConf(conf) @@ -175,11 +175,6 @@ class FileShuffleBlockManager(conf: SparkConf) } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockData(blockId) - Some(segment.nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { if (consolidateShuffleFiles) { // Search all file groups associated with this shuffle. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 87fd161e06c85..50edb5a34e333 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -27,6 +27,8 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ +import IndexShuffleBlockManager.NOOP_REDUCE_ID + /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. * Data of shuffle blocks from the same map task are stored in a single consolidated data file. @@ -39,25 +41,18 @@ import org.apache.spark.storage._ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockResolver { private lazy val blockManager = SparkEnv.get.blockManager private val transportConf = SparkTransportConf.fromSparkConf(conf) - /** - * Mapping to a single shuffleBlockId with reduce ID 0. - * */ - def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = { - ShuffleBlockId(shuffleId, mapId, 0) - } - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } /** @@ -97,10 +92,6 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - Some(getBlockData(blockId).nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index @@ -123,3 +114,11 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { override def stop(): Unit = {} } + +private[spark] object IndexShuffleBlockManager { + // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort + // shuffle outputs for several reduces are glommed into a single file. 
+ // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. + val NOOP_REDUCE_ID = 0 +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index b521f0c7fc77e..4342b0d598b16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -22,15 +22,19 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] -trait ShuffleBlockManager { +/** + * Implementers of this trait understand how to retrieve block data for a logical shuffle block + * identifier (i.e. map, reduce, and shuffle). Implementations may use files or file segments to + * encapsulate shuffle data. This is used by the BlockStore to abstract over different shuffle + * implementations when shuffle data is retrieved. + */ +trait ShuffleBlockResolver { type ShuffleId = Int /** - * Get shuffle block data managed by the local ShuffleBlockManager. - * @return Some(ByteBuffer) if block found, otherwise None. + * Retrieve the data for the specified block. If the data for that block is not available, + * throws an unspecified exception. */ - def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a44a8e1249256..978366d1a1d1b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -55,7 +55,10 @@ private[spark] trait ShuffleManager { */ def unregisterShuffle(shuffleId: Int): Boolean - def shuffleBlockManager: ShuffleBlockManager + /** + * Return a resolver capable of retrieving shuffle block data based on block coordinates. + */ + def shuffleBlockResolver: ShuffleBlockResolver /** Shut down this ShuffleManager. */ def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index b934480cfb9be..f6e6fe5defe09 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -23,7 +23,7 @@ import org.apache.spark.scheduler.MapStatus * Obtained inside a map task to write out records to the shuffle system. 
*/ private[spark] trait ShuffleWriter[K, V] { - /** Write a bunch of records to this task's output */ + /** Write a sequence of records to this task's output */ def write(records: Iterator[_ <: Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 62e0629b34400..2a7df8dd5bd83 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -53,20 +53,20 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { new HashShuffleWriter( - shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - shuffleBlockManager.removeShuffle(shuffleId) + shuffleBlockResolver.removeShuffle(shuffleId) } - override def shuffleBlockManager: FileShuffleBlockManager = { + override def shuffleBlockResolver: FileShuffleBlockManager = { fileShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bda30a56d808e..0497036192154 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -58,7 +58,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) new SortShuffleWriter( - shuffleBlockManager, baseShuffleHandle, mapId, context) + shuffleBlockResolver, baseShuffleHandle, mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -66,18 +66,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager if (shuffleMapNumber.containsKey(shuffleId)) { val numMaps = shuffleMapNumber.remove(shuffleId) (0 until numMaps).map{ mapId => - shuffleBlockManager.removeDataByMap(shuffleId, mapId) + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override def shuffleBlockManager: IndexShuffleBlockManager = { + override def shuffleBlockResolver: IndexShuffleBlockManager = { indexShuffleBlockManager } /** Shut down this ShuffleManager. 
*/ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 55ea0f17b156a..a066435df6fb0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -58,8 +58,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) + sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) sorter.insertAll(records) } @@ -67,7 +66,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) - val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) @@ -100,3 +99,4 @@ private[spark] class SortShuffleWriter[K, V, C]( } } } + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1dff09a75d038..fc31296f4deb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -301,7 +301,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -439,14 +439,10 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - val shuffleBlockManager = shuffleManager.shuffleBlockManager - shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match { - case Some(bytes) => - Some(bytes) - case None => - throw new BlockException( - blockId, s"Block $blockId not found on disk, though it should be") - } + val shuffleBlockManager = shuffleManager.shuffleBlockResolver + // TODO: This should gracefully handle case where local block is not available. Currently + // downstream code will throw an exception. 
+ Option(shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) } else { doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b962c101c91da..7bd3c7852a6b2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -664,6 +664,8 @@ private[spark] class ExternalSorter[K, V, C]( } /** + * Exposed for testing purposes. + * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -673,7 +675,7 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer if (spills.isEmpty && partitionWriters == null) { @@ -781,7 +783,7 @@ private[spark] class ExternalSorter[K, V, C]( /** * Read a partition file back as an iterator (used in our iterator method) */ - def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { if (writer.isOpen) { writer.commitAndClose() } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 6790388f96603..b834dc0e735eb 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -54,7 +54,7 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val shuffleBlockManager = - SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager] + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockManager] val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), new ShuffleWriteMetrics) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 15ee95070a3d3..6b666a0384879 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -59,7 +59,7 @@ object StoragePerfTester { val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { From 424e987dfebbbaa37f4496d44090d469a931ce76 Mon Sep 17 00:00:00 2001 From: Xiangrui 
Meng Date: Thu, 2 Apr 2015 17:57:01 +0800 Subject: [PATCH 120/171] [SPARK-6672][SQL] convert row to catalyst in createDataFrame(RDD[Row], ...) We assume that `RDD[Row]` contains Scala types. So we need to convert them into catalyst types in createDataFrame. liancheng Author: Xiangrui Meng Closes #5329 from mengxr/SPARK-6672 and squashes the following commits: 2d52644 [Xiangrui Meng] set needsConversion = false in jsonRDD 06896e4 [Xiangrui Meng] add createDataFrame without conversion 4a3767b [Xiangrui Meng] convert Row to catalyst --- .../spark/sql/catalyst/ScalaReflection.scala | 5 +++++ .../org/apache/spark/sql/DataFrame.scala | 3 ++- .../org/apache/spark/sql/SQLContext.scala | 20 ++++++++++++++++--- .../apache/spark/sql/parquet/newParquet.scala | 3 ++- .../apache/spark/sql/sources/commands.scala | 3 ++- .../spark/sql/test/ExamplePointUDT.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 9 ++++++++- 7 files changed, 37 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2220970085462..8bfd0471d9c7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -72,6 +72,11 @@ trait ScalaReflection { case (d: BigDecimal, _) => Decimal(d) case (d: java.math.BigDecimal, _) => Decimal(d) case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) + case (r: Row, structType: StructType) => + new GenericRow( + r.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) case (other, _) => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ce0890906bf1b..34be17325b2b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -904,7 +904,8 @@ class DataFrame private[sql]( */ override def repartition(numPartitions: Int): DataFrame = { sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema) + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), + schema, needsConversion = false) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1794936a52c6d..39dd14e796f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -392,9 +392,23 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema, needsConversion = true) + } + + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. 
- val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) + val catalystRows = if (needsConversion) { + rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]) + } else { + rowRDD + } + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) DataFrame(this, logicalPlan) } @@ -604,7 +618,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** @@ -633,7 +647,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 43f260d3ef8d3..e12531480ce92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -122,7 +122,8 @@ private[sql] class DefaultSource val df = sqlContext.createDataFrame( data.queryExecution.toRdd, - data.schema.asNullable) + data.schema.asNullable, + needsConversion = false) val createdRelation = createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 9bbe06e59ba30..dbdb0d39c26a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource( val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + val df = sqlContext.createDataFrame( + data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) relation.insert(df, overwrite) // Invalidate the cache. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index c11d0ae5bf1cc..2fdd798b44bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable /** * User-defined type for [[ExamplePoint]]. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6761d996fd975..5297cc01eddfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,7 +21,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.sql @@ -506,4 +506,11 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show() testData.select($"*").show(1000) } + + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { + val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) + val df = TestSQLContext.createDataFrame(rowRDD, schema) + df.rdd.collect() + } } From 0cce5451adfc6bf4661bcf67aca3db26376455fe Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Apr 2015 12:18:33 -0700 Subject: [PATCH 121/171] [SPARK-6667] [PySpark] remove setReuseAddress The reused address on server side had caused the server can not acknowledge the connected connections, remove it. This PR will retry once after timeout, it also add a timeout at client side. Author: Davies Liu Closes #5324 from davies/collect_hang and squashes the following commits: e5a51a2 [Davies Liu] remove setReuseAddress 7977c2f [Davies Liu] do retry on client side b838f35 [Davies Liu] retry after timeout --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 1 - python/pyspark/rdd.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 19f4c95fcad74..36cf2af0857dd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -605,7 +605,6 @@ private[spark] object PythonRDD extends Logging { */ private def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1) - serverSocket.setReuseAddress(true) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c337a43c8a7fc..2d05611321ed6 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -113,6 +113,7 @@ def _parse_memory(s): def _load_from_socket(port, serializer): sock = socket.socket() + sock.settimeout(3) try: sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) From e3202aa2e9bd140effbcf2a7a02b90cb077e760b Mon Sep 17 00:00:00 2001 From: Hung Lin Date: Thu, 2 Apr 2015 14:01:43 -0700 Subject: [PATCH 122/171] SPARK-6414: Spark driver failed with NPE on job cancelation Use Option for ActiveJob.properties to avoid NPE bug Author: Hung Lin Closes #5124 from hunglin/SPARK-6414 and squashes the following commits: 2290b6b [Hung Lin] [SPARK-6414][core] Fix NPE in SparkContext.cancelJobGroup() --- .../scala/org/apache/spark/SparkContext.scala | 4 +--- .../apache/spark/scheduler/DAGScheduler.scala | 10 +++++----- 
.../org/apache/spark/SparkContextSuite.scala | 20 ++++++++++++++++++- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a70be16f77eeb..3904f7d1060c5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -433,6 +433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Thread Local variable that can be used by users to pass information down the stack private val localProperties = new InheritableThreadLocal[Properties] { override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() } /** @@ -474,9 +475,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. */ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d35b4f9dbaf88..7227fa9da4317 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -493,7 +493,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -522,7 +522,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): Unit = { + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { @@ -542,7 +542,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null): PartialResult[R] = { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -689,7 +689,7 @@ class DAGScheduler( // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. 
val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -736,7 +736,7 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) { + properties: Properties) { var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b07c4d93db4e6..c7301a30d8b11 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files @@ -25,9 +26,11 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable - import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration + class SparkContextSuite extends FunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -173,4 +176,19 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } + + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) + sc.cancelJobGroup("nonExistGroupId") + Await.ready(future, Duration(2, TimeUnit.SECONDS)) + + // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause + // SparkContext to shutdown, so the following assertion will fail. + assert(sc.parallelize(1 to 10).count() == 10L) + } finally { + sc.stop() + } + } } From 4214e50fc32de1478584d8edfa3a35576c12c025 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 2 Apr 2015 16:01:03 -0700 Subject: [PATCH 123/171] [SQL] Throw UnsupportedOperationException instead of NotImplementedError NotImplementedError in scala 2.10 is a fatal exception, which is not very nice to throw when not actually fatal. 
Author: Michael Armbrust Closes #5315 from marmbrus/throwUnsupported and squashes the following commits: c29e03b [Michael Armbrust] [SQL] Throw UnsupportedOperationException instead of NotImplementedError 052e05b [Michael Armbrust] [SQL] Throw UnsupportedOperationException instead of NotImplementedError --- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 5 ++--- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 6bb1c47dba920..46991fbd68cde 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -184,9 +184,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.client.alterTable(tableFullName, new Table(hiveTTable)) } case otherRelation => - throw new NotImplementedError( - s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 1e05a024b8807..ccd0e5aa51f95 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -120,7 +120,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[NotImplementedError] { + intercept[UnsupportedOperationException] { analyze("tempTable") } catalog.unregisterTable(Seq("tempTable")) From 251698fb7335a3bb465f1cd0c29e7e74e0361f4a Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 2 Apr 2015 16:02:31 -0700 Subject: [PATCH 124/171] [SPARK-6655][SQL] We need to read the schema of a data source table stored in spark.sql.sources.schema property https://issues.apache.org/jira/browse/SPARK-6655 Author: Yin Huai Closes #5313 from yhuai/SPARK-6655 and squashes the following commits: 1e00c03 [Yin Huai] Unnecessary change. f131bd9 [Yin Huai] Fix. f1218c1 [Yin Huai] Failed test. 
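To make the compatibility path concrete, here is a minimal sketch of the fallback logic this patch adds: the legacy single property is consulted first, and otherwise the schema string is reassembled from the numbered part properties. The Map-based helper is illustrative only; the real code reads properties from a Hive Table object and then parses the resulting JSON with DataType.fromJson, as the diff below shows.

object SchemaPropertySketch {
  // Returns the JSON schema string stored for a data source table, if any.
  def schemaJsonFromProperties(props: Map[String, String]): Option[String] = {
    // Newer tables (after SPARK-6024) split the JSON schema across numbered parts.
    val fromParts = props.get("spark.sql.sources.schema.numParts").map { numParts =>
      (0 until numParts.toInt).map { index =>
        props.getOrElse(s"spark.sql.sources.schema.part.$index",
          sys.error(s"Could not read schema part $index from the metastore."))
      }.mkString
    }
    // Older tables stored the whole JSON schema in a single property; prefer it when present.
    props.get("spark.sql.sources.schema").orElse(fromParts)
  }
}
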
--- .../spark/sql/hive/HiveMetastoreCatalog.scala | 18 +++++++++++---- .../sql/hive/MetastoreDataSourcesSuite.scala | 23 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f0076cef13777..14cdb420731cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -70,7 +70,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val table = synchronized { client.getTable(in.database, in.name) } - val userSpecifiedSchema = + + def schemaStringFromParts: Option[String] = { Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts => val parts = (0 until numParts.toInt).map { index => val part = table.getProperty(s"spark.sql.sources.schema.part.${index}") @@ -82,10 +83,19 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with part } - // Stick all parts back to a single schema string in the JSON representation - // and convert it back to a StructType. - DataType.fromJson(parts.mkString).asInstanceOf[StructType] + // Stick all parts back to a single schema string. + parts.mkString } + } + + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we need to still support. + val schemaString = + Option(table.getProperty("spark.sql.sources.schema")).orElse(schemaStringFromParts) + + val userSpecifiedSchema = + schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e5ad0bf552073..e09c702c8969e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -25,6 +25,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql._ @@ -682,6 +684,27 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { assert(schema === actualSchema) } + test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { + val tableName = "spark6655" + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + // Manually create the metadata in metastore. 
+ val tbl = new Table("default", tableName) + tbl.setProperty("spark.sql.sources.provider", "json") + tbl.setProperty("spark.sql.sources.schema", schema.json) + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + tbl.setSerdeParam("path", catalog.hiveDefaultTableFilePath(tableName)) + catalog.synchronized { + catalog.client.createTable(tbl) + } + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + sql(s"drop table $tableName") + } + + test("insert into a table") { def createDF(from: Int, to: Int): DataFrame = createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") From d3944b6f2aeb36629bf89207629cc5e55d327241 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 2 Apr 2015 16:15:34 -0700 Subject: [PATCH 125/171] [Minor] [SQL] Follow-up of PR #5210 This PR addresses rxin's comments in PR #5210. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5219) Author: Cheng Lian Closes #5219 from liancheng/spark-6554-followup and squashes the following commits: 41f3a09 [Cheng Lian] Addresses comments in #5210 --- .../scala/org/apache/spark/sql/parquet/newParquet.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e12531480ce92..583bac42fdcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -434,17 +434,18 @@ private[sql] case class ParquetRelation2( FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) } - // Push down filters when possible. Notice that not all filters can be converted to Parquet - // filter predicate. Here we try to convert each individual predicate and only collect those - // convertible ones. + // Try to push down filters when filter push-down is enabled. if (sqlContext.conf.parquetFilterPushDown) { + val partitionColNames = partitionColumns.map(_.name).toSet predicates // Don't push down predicates which reference partition columns .filter { pred => - val partitionColNames = partitionColumns.map(_.name).toSet val referencedColNames = pred.references.map(_.name).toSet referencedColNames.intersect(partitionColNames).isEmpty } + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. .flatMap(ParquetFilters.createFilter) .reduceOption(FilterApi.and) .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) From 5db89127e72630aec7c5552f2c84018ae18d03fe Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 2 Apr 2015 16:46:50 -0700 Subject: [PATCH 126/171] [SPARK-6618][SPARK-6669][SQL] Lock Hive metastore client correctly. Author: Yin Huai Author: Michael Armbrust Closes #5333 from yhuai/lookupRelationLock and squashes the following commits: 59c884f [Michael Armbrust] [SQL] Lock metastore client in analyzeTable 7667030 [Yin Huai] Merge pull request #2 from marmbrus/pr/5333 e4a9b0b [Michael Armbrust] Correctly lock on MetastoreCatalog d6fc32f [Yin Huai] Missing `)`. 1e241af [Yin Huai] Protect InsertIntoHive. fee7e9c [Yin Huai] A test? 5416b0f [Yin Huai] Just protect client. 
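The essence of the fix is a single locking discipline: every direct call into the non-thread-safe Hive metastore client is wrapped in a synchronized block on the catalog. A minimal sketch of that pattern follows; the trait and class names are illustrative, not the actual Spark types.

trait MetastoreClient {
  def getTable(db: String, table: String): AnyRef
  def alterTable(fullName: String, table: AnyRef): Unit
}

class LockingCatalog(client: MetastoreClient) {
  // All client calls funnel through this object's monitor, so concurrent
  // queries cannot interleave metastore operations mid-call.
  def getTable(db: String, table: String): AnyRef = synchronized {
    client.getTable(db, table)
  }
  def alterTable(fullName: String, table: AnyRef): Unit = synchronized {
    client.alterTable(fullName, table)
  }
}

The sanity test added to SQLQuerySuite below exercises exactly this scenario by creating, looking up, and dropping a hundred tables in parallel.
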
--- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 14 +++-- .../hive/execution/InsertIntoHiveTable.scala | 51 +++++++++++-------- .../sql/hive/execution/SQLQuerySuite.scala | 11 ++++ 4 files changed, 53 insertions(+), 27 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 46991fbd68cde..7c6a7df2bd01e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -181,7 +181,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableFullName = relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName - catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + catalog.synchronized { + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } } case otherRelation => throw new UnsupportedOperationException( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 14cdb420731cd..bbd920a4051de 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -67,7 +67,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") - val table = synchronized { + val table = HiveMetastoreCatalog.this.synchronized { client.getTable(in.database, in.name) } @@ -183,12 +183,16 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = synchronized { + alias: Option[String]): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - val table = try client.getTable(databaseName, tblName) catch { + val table = try { + synchronized { + client.getTable(databaseName, tblName) + } + } catch { case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException => throw new NoSuchTableException } @@ -210,7 +214,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else { val partitions: Seq[Partition] = if (table.isPartitioned) { - HiveShim.getAllPartitionsOf(client, table).toSeq + synchronized { + HiveShim.getAllPartitionsOf(client, table).toSeq + } } else { Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index cdf012b5117be..6c96747439683 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -50,7 +50,7 @@ case class InsertIntoHiveTable( @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val db = 
Hive.get(sc.hiveconf) + @transient private lazy val catalog = sc.catalog private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] @@ -199,38 +199,45 @@ case class InsertIntoHiveTable( orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) + catalog.synchronized { + catalog.client.validatePartitionNameCharacters(partVals) + } // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) + catalog.synchronized { + catalog.client.loadDynamicPartitions( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir) + } } else { - db.loadPartition( + catalog.synchronized { + catalog.client.loadPartition( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + } + } else { + catalog.synchronized { + catalog.client.loadTable( outputPath, qualifiedTableName, - orderedPartitionSpec, overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + holdDDLTime) } - } else { - db.loadTable( - outputPath, - qualifiedTableName, - overwrite, - holdDDLTime) } // Invalidate the cache. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 310c2bfdf1011..2065f0d60d92f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -457,4 +457,15 @@ class SQLQuerySuite extends QueryTest { dropTempTable("data") setConf("spark.sql.hive.convertCTAS", originalConf) } + + test("sanity test for SPARK-6618") { + (1 to 100).par.map { i => + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } } From dfd2982bc7047732197f1d9ad77221e9c6076fc2 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 2 Apr 2015 17:20:31 -0700 Subject: [PATCH 127/171] [SQL][Minor] Use analyzed logical instead of unresolved in HiveComparisonTest Some internal unit test failed due to the logical plan node in pattern matching in `HiveComparisonTest`, e.g. https://github.com/apache/spark/blob/master/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala#L137 Which will may call the `output` function on an unresolved logical plan. 
Author: Cheng Hao Closes #4946 from chenghao-intel/logical and squashes the following commits: 432ecb3 [Cheng Hao] Use analyzed instead of logical in HiveComparisonTest --- .../apache/spark/sql/hive/execution/HiveComparisonTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 8f3285242091c..a5ec312ee430c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -138,7 +138,7 @@ abstract class HiveComparisonTest case _ => plan.children.iterator.exists(isSorted) } - val orderedAnswer = hiveQuery.logical match { + val orderedAnswer = hiveQuery.analyzed match { // Clean out non-deterministic time schema info. // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. @@ -299,7 +299,7 @@ abstract class HiveComparisonTest val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. - hiveQueries.foreach(_.logical) + hiveQueries.foreach(_.analyzed) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { case ((queryString, i), hiveQuery, cachedAnswerFile)=> try { From 947802cb0de581e51f8141f6663e896de3d753ce Mon Sep 17 00:00:00 2001 From: DoingDone9 <799203320@qq.com> Date: Thu, 2 Apr 2015 17:23:51 -0700 Subject: [PATCH 128/171] [SPARK-6243][SQL] The Operation of match did not conside the scenarios that order.dataType does not match NativeType It did not conside that order.dataType does not match NativeType. So i add "case other => ..." for other cenarios. Author: DoingDone9 <799203320@qq.com> Closes #4959 from DoingDone9/case_ and squashes the following commits: 6278846 [DoingDone9] Update rows.scala cb1852d [DoingDone9] Merge pull request #2 from apache/master c3f046f [DoingDone9] Merge pull request #1 from apache/master --- .../scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index a8983df208318..0a275b84086cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -224,6 +224,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case n: NativeType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case other => sys.error(s"Type $other does not support ordered operations") } if (comparison != 0) return comparison } From 052dee0707830cfd3cd8821ecc3471a37ede294a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 2 Apr 2015 18:30:55 -0700 Subject: [PATCH 129/171] [SPARK-6686][SQL] Use resolved output instead of names for toDF rename This is a workaround for a problem reported on the user list. This doesn't fix the core problem, but in general is a more robust way to do renames. 
Author: Michael Armbrust Closes #5337 from marmbrus/toDFrename and squashes the following commits: 6a3159d [Michael Armbrust] [SPARK-6686][SQL] Use resolved output instead of names for toDF rename --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 ++-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 34be17325b2b0..5c6016a4a2ce2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -240,8 +240,8 @@ class DataFrame private[sql]( s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + s"New column names (${colNames.size}): " + colNames.mkString(", ")) - val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) => - apply(oldName).as(newName) + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + Column(oldAttribute).as(newName) } select(newCols :_*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5297cc01eddfc..1db0cf7daac03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -60,6 +60,14 @@ class DataFrameSuite extends QueryTest { assert($"test".toString === "test") } + test("rename nested groupby") { + val df = Seq((1,(1,1))).toDF() + + checkAnswer( + df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"), + Row(1, 1) :: Nil) + } + test("invalid plan toString, debug mode") { val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") From 45134ec920c3766c22aefd4366b4b60ec99bd810 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 2 Apr 2015 19:48:55 -0700 Subject: [PATCH 130/171] [SPARK-6650] [core] Stop ExecutorAllocationManager when context stops. This fixes the thread leak. I also changed the unit test to keep track of allocated contexts and make sure they're closed after tests are run; this is needed since some tests use this pattern: val sc = createContext() doSomethingThatMayThrow() sc.stop() Author: Marcelo Vanzin Closes #5311 from vanzin/SPARK-6650 and squashes the following commits: 652c73b [Marcelo Vanzin] Nits. 5711512 [Marcelo Vanzin] More exception safety. cc5a744 [Marcelo Vanzin] Stop alloc manager before scheduler. 9886f69 [Marcelo Vanzin] [SPARK-6650] [core] Stop ExecutorAllocationManager when context stops. 
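The fix replaces the hand-rolled while(true) polling thread with a single-threaded scheduled executor that the new stop() method shuts down when the context stops. A stripped-down sketch of that pattern is below; the schedule() body is a placeholder, not Spark's actual allocation logic.

import java.util.concurrent.{Executors, TimeUnit}

class PollingService(intervalMillis: Long) {
  private val executor = Executors.newSingleThreadScheduledExecutor()

  def start(): Unit = {
    val task = new Runnable {
      // Catch exceptions so one bad iteration does not kill the schedule.
      override def run(): Unit =
        try schedule() catch { case e: Exception => e.printStackTrace() }
    }
    executor.scheduleAtFixedRate(task, 0, intervalMillis, TimeUnit.MILLISECONDS)
  }

  // Without this, the scheduling thread leaks after its owner shuts down.
  def stop(): Unit = {
    executor.shutdown()
    executor.awaitTermination(10, TimeUnit.SECONDS)
  }

  private def schedule(): Unit = () // placeholder: decide whether to add or remove executors
}
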
--- .../spark/ExecutorAllocationManager.scala | 38 ++++++++-------- .../scala/org/apache/spark/SparkContext.scala | 3 +- .../ExecutorAllocationManagerSuite.scala | 44 ++++++++++++------- 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 21c6e6ffa6666..9385f557c4614 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,10 +17,12 @@ package org.apache.spark +import java.util.concurrent.{Executors, TimeUnit} + import scala.collection.mutable import org.apache.spark.scheduler._ -import org.apache.spark.util.{SystemClock, Clock} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -129,6 +131,10 @@ private[spark] class ExecutorAllocationManager( // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener + // Executor that handles the scheduling task. + private val executor = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("spark-dynamic-executor-allocation")) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -173,32 +179,24 @@ private[spark] class ExecutorAllocationManager( } /** - * Register for scheduler callbacks to decide when to add and remove executors. + * Register for scheduler callbacks to decide when to add and remove executors, and start + * the scheduling task. */ def start(): Unit = { listenerBus.addListener(listener) - startPolling() + + val scheduleTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + } + executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } /** - * Start the main polling thread that keeps track of when to add and remove executors. + * Stop the allocation manager. */ - private def startPolling(): Unit = { - val t = new Thread { - override def run(): Unit = { - while (true) { - try { - schedule() - } catch { - case e: Exception => logError("Exception in dynamic executor allocation thread!", e) - } - Thread.sleep(intervalMillis) - } - } - } - t.setName("spark-dynamic-executor-allocation") - t.setDaemon(true) - t.start() + def stop(): Unit = { + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3904f7d1060c5..5b3778ead6994 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1136,7 +1136,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return whether dynamically adjusting the amount of resources allocated to * this application is supported. This is currently only available for YARN. 
*/ - private[spark] def supportDynamicAllocation = + private[spark] def supportDynamicAllocation = master.contains("yarn") || dynamicAllocationTesting /** @@ -1400,6 +1400,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli env.metricsSystem.report() metadataCleaner.cancel() cleaner.foreach(_.stop()) + executorAllocationManager.foreach(_.stop()) dagScheduler.stop() dagScheduler = null listenerBus.stop() diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index abfcee75728dc..3ded1e4af8742 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -28,10 +28,20 @@ import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { +class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ + private val contexts = new mutable.ListBuffer[SparkContext]() + + before { + contexts.clear() + } + + after { + contexts.foreach(_.stop()) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("local") @@ -39,18 +49,19 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") val sc0 = new SparkContext(conf) + contexts += sc0 assert(sc0.executorAllocationManager.isDefined) sc0.stop() // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") - intercept[SparkException] { new SparkContext(conf1) } + intercept[SparkException] { contexts += new SparkContext(conf1) } SparkEnv.get.stop() SparkContext.clearActiveContext() // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") - intercept[SparkException] { new SparkContext(conf2) } + intercept[SparkException] { contexts += new SparkContext(conf2) } SparkEnv.get.stop() SparkContext.clearActiveContext() @@ -665,16 +676,6 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("executor-2")) assert(!removeTimes(manager).contains("executor-1")) } -} - -/** - * Helper methods for testing ExecutorAllocationManager. - * This includes methods to access private methods and fields in ExecutorAllocationManager. 
- */ -private object ExecutorAllocationManagerSuite extends PrivateMethodTester { - private val schedulerBacklogTimeout = 1L - private val sustainedSchedulerBacklogTimeout = 2L - private val executorIdleTimeout = 3L private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() @@ -688,9 +689,22 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { sustainedSchedulerBacklogTimeout.toString) .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) .set("spark.dynamicAllocation.testing", "true") - new SparkContext(conf) + val sc = new SparkContext(conf) + contexts += sc + sc } +} + +/** + * Helper methods for testing ExecutorAllocationManager. + * This includes methods to access private methods and fields in ExecutorAllocationManager. + */ +private object ExecutorAllocationManagerSuite extends PrivateMethodTester { + private val schedulerBacklogTimeout = 1L + private val sustainedSchedulerBacklogTimeout = 2L + private val executorIdleTimeout = 3L + private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details") } From 4b82bd730a24f96d94dfea87420cfaa4253a5ccb Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 2 Apr 2015 20:23:08 -0700 Subject: [PATCH 131/171] [SPARK-6575][SQL] Converted Parquet Metastore tables no longer cache metadata https://issues.apache.org/jira/browse/SPARK-6575 Author: Yin Huai Closes #5339 from yhuai/parquetRelationCache and squashes the following commits: 83d9846 [Yin Huai] Remove unnecessary change. c0dc7a4 [Yin Huai] Cache converted parquet relations. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 61 +++++++++- .../apache/spark/sql/hive/parquetSuites.scala | 112 ++++++++++++++++++ 2 files changed, 167 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index bbd920a4051de..76d329a3ddcdf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -116,7 +116,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } override def refreshTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + // refresh table does not eagerly reload the cache. It just invalidate the cache. + // Next time when we use the table, it will be populated in the cache. + invalidateTable(databaseName, tableName) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -229,13 +231,42 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - val parquetOptions = Map( - ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. 
+ val parquetOptions = Map( + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + + def getCached( + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached Parquet Relation. + val useCached = + parquetRelation.paths == pathsInMetastore && + logical.schema.sameType(metastoreSchema) && + parquetRelation.maybePartitionSpec == partitionSpecInMetastore + + if (useCached) Some(logical) else None + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} shold be stored " + + s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + } + if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) @@ -248,10 +279,28 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } val partitionSpec = PartitionSpec(partitionSchema, partitions) val paths = partitions.map(_.path) - LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } else { val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) - LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, None) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 432d65a874518..2ad6e867262b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -26,8 +26,10 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.SaveMode @@ -390,6 +392,116 @@ class 
ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE ms_convert") } + + test("Caching converted data source Parquet Relations") { + def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + // Converted test_parquet should be cached. + catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + case null => fail("Converted test_parquet should be cached in the cache.") + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case other => + fail( + "The cached test_parquet should be a Parquet Relation. " + + s"However, $other is returned form the cache.") + } + } + + sql("DROP TABLE IF EXISTS test_insert_parquet") + sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + + // First, make sure the converted test_parquet is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + // Table lookup will make the table cached. + table("test_insert_parquet") + checkCached(tableIdentifer) + // For insert into non-partitioned table, we will do the conversion, + // so the converted test_insert_parquet should be cached. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_insert_parquet + |select a, b from jt + """.stripMargin) + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) + // Invalidate the cache. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Create a partitioned table. + sql( + """ + |create table test_parquet_partitioned_cache_test + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (date string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-01') + |select a, b from jt + """.stripMargin) + // Right now, insert into a partitioned Parquet is not supported in data source Parquet. + // So, we expect it is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-02') + |select a, b from jt + """.stripMargin) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + + // Make sure we can cache the partitioned table. 
+ table("test_parquet_partitioned_cache_test") + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql( + """ + |select b, '2015-04-01', a FROM jt + |UNION ALL + |select b, '2015-04-02', a FROM jt + """.stripMargin).collect()) + + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + sql("DROP TABLE test_insert_parquet") + sql("DROP TABLE test_parquet_partitioned_cache_test") + } } class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { From 8a0aa81ca37d337423db60edb09cf264cc2c6498 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei Date: Thu, 2 Apr 2015 20:24:31 -0700 Subject: [PATCH 132/171] [CORE] The descriptionof jobHistory config should be spark.history.fs.logDirectory The config option is spark.history.fs.logDirectory, not spark.fs.history.logDirectory. So the descriptionof should be changed. Thanks. Author: KaiXinXiaoLei Closes #5332 from KaiXinXiaoLei/historyConfig and squashes the following commits: 5ffbfb5 [KaiXinXiaoLei] the describe of jobHistory config is error --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 80c9c13ddec1e..9d40d8c8fd7a8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -118,7 +118,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!fs.exists(path)) { var msg = s"Log directory specified does not exist: $logDir." if (logDir == DEFAULT_LOG_DIR) { - msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + msg += " Did you configure the correct one through spark.history.fs.logDirectory?" } throw new IllegalArgumentException(msg) } From 6e1c1ec67bc4d7e5700f523ec08db6bb25bd2302 Mon Sep 17 00:00:00 2001 From: freeman Date: Thu, 2 Apr 2015 21:37:44 -0700 Subject: [PATCH 133/171] [SPARK-6345][STREAMING][MLLIB] Fix for training with prediction This patch fixes a reported bug causing model updates to not properly propagate to model predictions during streaming regression. These minor changes in model declaration fix the problem, and I expanded the tests to include the scenario in which the bug was arising. The two new tests failed prior to the patch and now pass. 
cc mengxr Author: freeman Closes #5037 from freeman-lab/train-predict-fix and squashes the following commits: 3af953e [freeman] Expand test coverage to include combined training and prediction 8f84fc8 [freeman] Move model declaration --- .../StreamingLogisticRegressionWithSGD.scala | 2 ++ .../regression/StreamingLinearAlgorithm.scala | 6 ++-- .../StreamingLinearRegressionWithSGD.scala | 2 ++ .../StreamingLogisticRegressionSuite.scala | 27 ++++++++++++++++++ .../StreamingLinearRegressionSuite.scala | 28 +++++++++++++++++++ 5 files changed, 62 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index b89f38cf5aba4..7d33df3221fbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -63,6 +63,8 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( protected val algorithm = new LogisticRegressionWithSGD( stepSize, numIterations, regParam, miniBatchFraction) + protected var model: Option[LogisticRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index ce95c063db970..cea8f3f47307b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -60,7 +60,7 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: Option[M] = None + protected var model: Option[M] /** The algorithm to use for updating. */ protected val algorithm: A @@ -114,7 +114,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.get.predict) + data.map{x => model.get.predict(x)} } /** Java-friendly version of `predictOn`. */ @@ -132,7 +132,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.get.predict) + data.mapValues{x => model.get.predict(x)} } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e5e6301127a28..a49153bf73c0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] ( val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + protected var model: Option[LinearRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. 
*/ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 8b3e6e5ce9249..d50c43d439187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { assert(errors.forall(x => x <= 0.4)) } + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert(error.head > 0.8 & error.last < 0.2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 70b43ddb7daf5..24fd8df691817 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) } + + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert((error.head - error.last) > 2) + } } From 440ea31b76aa7e813436271fd63880c7bcd69157 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 2 Apr 2015 22:54:30 -0700 
Subject: [PATCH 134/171] [SPARK-6621][Core] Fix the bug that calling EventLoop.stop in EventLoop.onReceive/onError/onStart doesn't call onStop Author: zsxwing Closes #5280 from zsxwing/SPARK-6621 and squashes the following commits: 521125e [zsxwing] Fix the bug that calling EventLoop.stop in EventLoop.onReceive and EventLoop.onError doesn't call onStop --- .../org/apache/spark/util/EventLoop.scala | 18 ++++- .../apache/spark/util/EventLoopSuite.scala | 72 +++++++++++++++++++ 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index b0ed908b84424..e9b2b8d24b476 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -76,9 +76,21 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { def stop(): Unit = { if (stopped.compareAndSet(false, true)) { eventThread.interrupt() - eventThread.join() - // Call onStop after the event thread exits to make sure onReceive happens before onStop - onStop() + var onStopCalled = false + try { + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStopCalled = true + onStop() + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + if (!onStopCalled) { + // ie is thrown from `eventThread.join()`. Otherwise, we should not call `onStop` since + // it's already called. + onStop() + } + } } else { // Keep quiet to allow calling `stop` multiple times. } diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 1026cb2aa7cae..47b535206c949 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -203,4 +203,76 @@ class EventLoopSuite extends FunSuite with Timeouts { assert(!eventLoop.isActive) } } + + test("EventLoop: stop() in onStart should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onStart(): Unit = { + stop() + } + + override def onReceive(event: Int): Unit = { + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onReceive should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onError should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw new RuntimeException("Oops") + } + + override def onError(e: Throwable): Unit = { + stop() + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) 
+ } + assert(onStopCalled) + } } From c42c3fc7f7b79a1f6ce990d39b5d9d14ab19fcf0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 3 Apr 2015 14:40:36 +0800 Subject: [PATCH 135/171] [SPARK-6575][SQL] Converted Parquet Metastore tables no longer cache metadata https://issues.apache.org/jira/browse/SPARK-6575 Author: Yin Huai This patch had conflicts when merged, resolved by Committer: Cheng Lian Closes #5339 from yhuai/parquetRelationCache and squashes the following commits: b0e1a42 [Yin Huai] Address comments. 83d9846 [Yin Huai] Remove unnecessary change. c0dc7a4 [Yin Huai] Cache converted parquet relations. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 28 +++++++++++++------ .../spark/sql/hive/execution/commands.scala | 5 ++-- .../apache/spark/sql/hive/parquetSuites.scala | 2 -- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 76d329a3ddcdf..c4da34ae645b8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -116,8 +116,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } override def refreshTable(databaseName: String, tableName: String): Unit = { - // refresh table does not eagerly reload the cache. It just invalidate the cache. + // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. + // Since we also cache ParquetRealtions converted from Hive Parquet tables and + // adding converted ParquetRealtions into the cache is not defined in the load function + // of the cache (instead, we add the cache entry in convertToParquetRelation), + // it is better at here to invalidate the cache to avoid confusing waring logs from the + // cache loader (e.g. cannot find data source provider, which is only defined for + // data source table.). invalidateTable(databaseName, tableName) } @@ -242,21 +248,27 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) def getCached( - tableIdentifier: QualifiedTableName, - pathsInMetastore: Seq[String], - schemaInMetastore: StructType, - partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => + case logical@LogicalRelation(parquetRelation: ParquetRelation2) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. val useCached = - parquetRelation.paths == pathsInMetastore && + parquetRelation.paths.toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && parquetRelation.maybePartitionSpec == partitionSpecInMetastore - if (useCached) Some(logical) else None + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. 
+ cachedDataSourceTables.invalidate(tableIdentifier) + None + } case other => logWarning( s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} shold be stored " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 4345ffbf30f77..99dc58646ddd6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -58,12 +58,13 @@ case class DropTable( try { hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { - // This table's metadata is not in + // This table's metadata is not in Hive metastore (e.g. the table does not exist). case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. // Users should be able to drop such kinds of tables regardless if there is an error. - case e: Throwable => log.warn(s"${e.getMessage}") + case e: Throwable => log.warn(s"${e.getMessage}", e) } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 2ad6e867262b1..1319c81dfc131 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -473,7 +473,6 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -481,7 +480,6 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { |select a, b from jt """.stripMargin) assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") From 82701ee25fda64f03899713bc56f82ca6f278151 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 3 Apr 2015 01:25:02 -0700 Subject: [PATCH 136/171] [SPARK-6428] Turn on explicit type checking for public methods. This builds on my earlier pull requests and turns on the explicit type checking in scalastyle. Author: Reynold Xin Closes #5342 from rxin/SPARK-6428 and squashes the following commits: 7b531ab [Reynold Xin] import ordering 2d9a8a5 [Reynold Xin] jl e668b1c [Reynold Xin] override 9b9e119 [Reynold Xin] Parenthesis. 82e0cf5 [Reynold Xin] [SPARK-6428] Turn on explicit type checking for public methods. 
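The change is mechanical: inferred result types on public and protected members are spelled out so the scalastyle rule being enabled in scalastyle-config.xml can enforce them. One representative hunk (LogQuery.Stats), extracted here as plain Scala for reference:

// Before: merge and toString relied on type inference.
// After: both carry explicit result types, as the enabled style rule requires.
class Stats(val count: Int, val numBytes: Int) extends Serializable {
  def merge(other: Stats): Stats = new Stats(count + other.count, numBytes + other.numBytes)
  override def toString: String = "bytes=%s\tn=%s".format(numBytes, count)
}

The same treatment is applied across the Java API wrappers, the examples, the streaming connectors, and GraphX in the diff that follows.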
--- .../apache/spark/api/java/JavaPairRDD.scala | 2 +- .../org/apache/spark/api/java/JavaRDD.scala | 2 +- .../apache/spark/api/java/JavaRDDLike.scala | 53 ++++++++++------ .../apache/spark/examples/LocalKMeans.scala | 4 +- .../org/apache/spark/examples/LocalLR.scala | 4 +- .../org/apache/spark/examples/LogQuery.scala | 4 +- .../org/apache/spark/examples/SparkLR.scala | 4 +- .../org/apache/spark/examples/SparkTC.scala | 2 +- .../spark/examples/bagel/PageRankUtils.scala | 2 +- .../spark/examples/mllib/MovieLensALS.scala | 4 +- .../examples/streaming/ActorWordCount.scala | 6 +- .../RecoverableNetworkWordCount.scala | 3 +- .../examples/streaming/ZeroMQWordCount.scala | 6 +- .../clickstream/PageViewGenerator.scala | 2 +- .../streaming/flume/FlumeInputDStream.scala | 12 ++-- .../kafka/DirectKafkaInputDStream.scala | 5 +- .../spark/streaming/kafka/KafkaRDD.scala | 4 +- .../twitter/TwitterInputDStream.scala | 2 +- .../streaming/zeromq/ZeroMQReceiver.scala | 13 ++-- .../org/apache/spark/graphx/EdgeContext.scala | 3 +- .../apache/spark/graphx/EdgeDirection.scala | 12 ++-- .../org/apache/spark/graphx/EdgeTriplet.scala | 2 +- .../spark/graphx/impl/EdgePartition.scala | 14 ++--- .../spark/graphx/impl/EdgeRDDImpl.scala | 4 +- .../graphx/impl/ReplicatedVertexView.scala | 2 +- .../spark/graphx/impl/VertexRDDImpl.scala | 4 +- .../graphx/lib/ConnectedComponents.scala | 2 +- .../spark/graphx/lib/LabelPropagation.scala | 4 +- .../apache/spark/graphx/lib/PageRank.scala | 2 +- .../GraphXPrimitiveKeyOpenHashMap.scala | 8 +-- .../apache/spark/mllib/feature/Word2Vec.scala | 2 +- scalastyle-config.xml | 2 +- .../apache/spark/sql/AnalysisException.scala | 2 +- .../spark/sql/catalyst/analysis/package.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 3 +- .../apache/spark/streaming/Checkpoint.scala | 2 +- .../streaming/api/java/JavaDStreamLike.scala | 12 ++-- .../streaming/api/java/JavaPairDStream.scala | 2 +- .../api/java/JavaStreamingContext.scala | 10 +-- .../spark/streaming/dstream/DStream.scala | 2 +- .../tools/JavaAPICompletenessChecker.scala | 4 +- .../spark/tools/StoragePerfTester.scala | 6 +- .../spark/deploy/yarn/ApplicationMaster.scala | 62 ++++++++++--------- .../spark/deploy/yarn/ExecutorRunnable.scala | 6 +- 46 files changed, 170 insertions(+), 142 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index a023712be1166..8441bb3a3047e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -661,7 +661,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.call(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 18ccd625fc8d1..db4e996feb31c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -192,7 +192,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sortBy[S](f: JFunction[T, S], ascending: Boolean, 
numPartitions: Int): JavaRDD[T] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x) + def fn: (T) => S = (x: T) => f.call(x) import com.google.common.collect.Ordering // shadows scala.math.Ordering implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] implicit val ctag: ClassTag[S] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 8da42934a7d96..8bf0627fc420d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} +import java.util.{Comparator, List => JList, Iterator => JIterator} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -93,7 +94,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * of the original partition. */ def mapPartitionsWithIndex[R]( - f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) @@ -109,7 +110,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassTag[(K2, V2)]] + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -119,7 +120,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -129,8 +130,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -139,8 +140,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - def cm = implicitly[ClassTag[(K2, V2)]] + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -148,7 +149,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. 
*/ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -157,7 +160,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -166,8 +171,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } + new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -175,7 +182,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -184,7 +193,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) } @@ -194,7 +205,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -277,8 +290,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def zipPartitions[U, V]( other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { - def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { 
+ (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) } @@ -441,8 +456,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final * combine step happens locally on the master, equivalent to running a single reduce task. */ - def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + def countByValue(): java.util.Map[T, jl.Long] = + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 17624c20cff3d..f73eac1e2b906 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -40,8 +40,8 @@ object LocalKMeans { val convergeDist = 0.001 val rand = new Random(42) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DenseVector[Double]] = { + def generatePoint(i: Int): DenseVector[Double] = { DenseVector.fill(D){rand.nextDouble * R} } Array.tabulate(N)(generatePoint) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 92a683ad57ea1..a55e0dc8d36c2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -37,8 +37,8 @@ object LocalLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 74620ad007d83..32e02eab8b031 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -54,8 +54,8 @@ object LogQuery { // scalastyle:on /** Tracks the total query count and number of aggregate bytes for a particular group. 
*/ class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) + def merge(other: Stats): Stats = new Stats(count + other.count, numBytes + other.numBytes) + override def toString: String = "bytes=%s\tn=%s".format(numBytes, count) } def extractKey(line: String): (String, String, String) = { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 257a7d29f922a..8c01a60844620 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -42,8 +42,8 @@ object SparkLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index f7f83086df3db..772cd897f5140 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -31,7 +31,7 @@ object SparkTC { val numVertices = 100 val rand = new Random(42) - def generateGraph = { + def generateGraph: Seq[(Int, Int)] = { val edges: mutable.Set[(Int, Int)] = mutable.Set.empty while (edges.size < numEdges) { val from = rand.nextInt(numVertices) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index e322d4ce5a745..ab6e63deb3c95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -90,7 +90,7 @@ class PRMessage() extends Message[String] with Serializable { } class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = { val hash = key match { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 1f4ca4fbe7778..0bc36ea65e1ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -178,7 +178,9 @@ object MovieLensALS { def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) : Double = { - def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + def mapPredictedRating(r: Double): Double = { + if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + } val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) val predictionsAndRatings = predictions.map{ x => diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index b433082dce1a2..92867b44be138 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -85,13 +85,13 @@ extends Actor with ActorHelper { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - override def preStart = remotePublisher ! SubscribeReceiver(context.self) + override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - def receive = { + def receive: PartialFunction[Any, Unit] = { case msg => store(msg.asInstanceOf[T]) } - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index c3a05c89d817e..751b30ea15782 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -55,7 +55,8 @@ import org.apache.spark.util.IntParam */ object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) + : StreamingContext = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6510c70bd1866..e99d1baa72b9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -35,7 +35,7 @@ import org.apache.spark.SparkConf */ object SimpleZeroMQPublisher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: SimpleZeroMQPublisher ") System.exit(1) @@ -45,7 +45,7 @@ object SimpleZeroMQPublisher { val acs: ActorSystem = ActorSystem() val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String) = ByteString(x) + implicit def stringToByteString(x: String): ByteString = ByteString(x) val messages: List[ByteString] = List("words ", "may ", "count ") while (true) { Thread.sleep(1000) @@ -86,7 +86,7 @@ object ZeroMQWordCount { // Create the context and set the batch size val ssc = new StreamingContext(sparkConf, Seconds(2)) - def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator + def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator // For this stream, a zeroMQ publisher should be running. 
val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 8402491b62671..54d996b8ac990 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -94,7 +94,7 @@ object PageViewGenerator { while (true) { val socket = listener.accept() new Thread() { - override def run = { + override def run(): Unit = { println("Got client connected from: " + socket.getInetAddress) val out = new PrintWriter(socket.getOutputStream(), true) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2de2a7926bfd1..60e2994431b38 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -37,8 +37,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver -import org.jboss.netty.channel.ChannelPipelineFactory -import org.jboss.netty.channel.Channels +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ @@ -187,8 +186,8 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def preferredLocation = Some(host) - + override def preferredLocation: Option[String] = Option(host) + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. 
* @@ -198,13 +197,12 @@ class FlumeReceiver( */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - - def getPipeline() = { + def getPipeline(): ChannelPipeline = { val pipeline = Channels.pipeline() val encoder = new ZlibEncoder(6) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) pipeline + } } } -} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 04e65cb3d708c..1b1fc8051d052 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -129,8 +129,9 @@ class DirectKafkaInputDStream[ private[streaming] class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def batchForTime = data.asInstanceOf[mutable.HashMap[ - Time, Array[OffsetRange.OffsetRangeTuple]]] + def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = { + data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]] + } override def update(time: Time) { batchForTime.clear() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 6d465bcb6bfc0..4a83b715fa89d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -155,7 +155,7 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close() = consumer.close() + override def close(): Unit = consumer.close() override def getNext(): R = { if (iter == null || !iter.hasNext) { @@ -207,7 +207,7 @@ object KafkaRDD { fromOffsets: Map[TopicAndPartition, Long], untilOffsets: Map[TopicAndPartition, LeaderOffset], messageHandler: MessageAndMetadata[K, V] => R - ): KafkaRDD[K, V, U, T, R] = { + ): KafkaRDD[K, V, U, T, R] = { val leaders = untilOffsets.map { case (tp, lo) => tp -> (lo.host, lo.port) }.toMap diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 4eacc47da5699..7cf02d85d73d3 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -70,7 +70,7 @@ class TwitterReceiver( try { val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { + def onStatus(status: Status): Unit = { store(status) } // Unimplemented diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 554705878ee78..588e6bac7b14a 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -29,13 +29,16 @@ import org.apache.spark.streaming.receiver.ActorHelper /** * A receiver to subscribe to ZeroMQ stream. 
*/ -private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) +private[streaming] class ZeroMQReceiver[T: ClassTag]( + publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[ByteString] => Iterator[T]) extends Actor with ActorHelper with Logging { - override def preStart() = ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + override def preStart(): Unit = { + ZeroMQExtension(context.system) + .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + } def receive: Receive = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala index d8be02e2023d5..23430179f12ec 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -62,7 +62,6 @@ object EdgeContext { * , _ + _) * }}} */ - def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]) = + def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]): Some[(VertexId, VertexId, VD, VD, ED)] = Some(edge.srcId, edge.dstId, edge.srcAttr, edge.dstAttr, edge.attr) } - diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 6f03eb1439773..058c8c8aa1b24 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -34,12 +34,12 @@ class EdgeDirection private (private val name: String) extends Serializable { override def toString: String = "EdgeDirection." + name - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case other: EdgeDirection => other.name == name case _ => false } - override def hashCode = name.hashCode + override def hashCode: Int = name.hashCode } @@ -48,14 +48,14 @@ class EdgeDirection private (private val name: String) extends Serializable { */ object EdgeDirection { /** Edges arriving at a vertex. */ - final val In = new EdgeDirection("In") + final val In: EdgeDirection = new EdgeDirection("In") /** Edges originating from a vertex. */ - final val Out = new EdgeDirection("Out") + final val Out: EdgeDirection = new EdgeDirection("Out") /** Edges originating from *or* arriving at a vertex of interest. */ - final val Either = new EdgeDirection("Either") + final val Either: EdgeDirection = new EdgeDirection("Either") /** Edges originating from *and* arriving at a vertex of interest. 
*/ - final val Both = new EdgeDirection("Both") + final val Both: EdgeDirection = new EdgeDirection("Both") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index 9d473d5ebda44..c8790cac3d8a0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -62,7 +62,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { def vertexAttr(vid: VertexId): VD = if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } - override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() + override def toString: String = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 373af75448374..c561570809253 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -324,7 +324,7 @@ class EdgePartition[ * * @return an iterator over edges in the partition */ - def iterator = new Iterator[Edge[ED]] { + def iterator: Iterator[Edge[ED]] = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] private[this] var pos = 0 @@ -351,7 +351,7 @@ class EdgePartition[ override def hasNext: Boolean = pos < EdgePartition.this.size - override def next() = { + override def next(): EdgeTriplet[VD, ED] = { val triplet = new EdgeTriplet[VD, ED] val localSrcId = localSrcIds(pos) val localDstId = localDstIds(pos) @@ -518,11 +518,11 @@ private class AggregatingEdgeContext[VD, ED, A]( _attr = attr } - override def srcId = _srcId - override def dstId = _dstId - override def srcAttr = _srcAttr - override def dstAttr = _dstAttr - override def attr = _attr + override def srcId: VertexId = _srcId + override def dstId: VertexId = _dstId + override def srcAttr: VD = _srcAttr + override def dstAttr: VD = _dstAttr + override def attr: ED = _attr override def sendToSrc(msg: A) { send(_localSrcId, msg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 43a3aea0f6196..c88b2f65a86cd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -70,9 +70,9 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 8ab255bd4038c..1df86449fa0c2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -50,7 +50,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * Return a new `ReplicatedVertexView` where edges are reversed and shipping levels are swapped to * match. 
*/ - def reverse() = { + def reverse(): ReplicatedVertexView[VD, ED] = { val newEdges = edges.mapEdgePartitions((pid, part) => part.reverse) new ReplicatedVertexView(newEdges, hasDstId, hasSrcId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 349c8545bf201..33ac7b0ed6095 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -71,9 +71,9 @@ class VertexRDDImpl[VD] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index e2f6cc138958e..859f896039047 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -37,7 +37,7 @@ object ConnectedComponents { */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(edge: EdgeTriplet[VertexId, ED]) = { + def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { Iterator((edge.dstId, edge.srcAttr)) } else if (edge.srcAttr > edge.dstAttr) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 82e9e06515179..2bcf8684b8b8e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]) = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) @@ -54,7 +54,7 @@ object LabelPropagation { i -> (count1Val + count2Val) }.toMap } - def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]) = { + def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = { if (message.isEmpty) attr else message.maxBy(_._2)._1 } val initialMessage = Map[VertexId, Long]() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 570440ba4441f..042e366a29f58 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -156,7 +156,7 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } - def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) } else { diff 
--git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 57b01b6f2e1fb..e2754ea699da9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -56,7 +56,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = keySet.size + override def size: Int = keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -112,7 +112,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -128,9 +128,9 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 9ee7e4a66b535..b2d9053f70145 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -522,7 +522,7 @@ object Word2VecModel extends Loader[Word2VecModel] { new Word2VecModel(word2VecMap) } - def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]) = { + def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { val sqlContext = new SQLContext(sc) import sqlContext.implicits._ diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 459a5035d4984..7168d5b2a8e26 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,7 +137,7 @@ - + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 34fedead44db3..f9992185a4563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -30,7 +30,7 @@ class AnalysisException protected[sql] ( val startPosition: Option[Int] = None) extends Exception with Serializable { - def withPosition(line: Option[Int], startPosition: Option[Int]) = { + def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { val newException = new AnalysisException(message, line, startPosition) newException.setStackTrace(getStackTrace) newException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index c61c395cb4bb1..7731336d247db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -44,7 +44,7 @@ package object analysis { } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. 
*/ - def withPosition[A](t: TreeNode[_])(f: => A) = { + def withPosition[A](t: TreeNode[_])(f: => A): A = { try f catch { case a: AnalysisException => throw a.withPosition(t.origin.line, t.origin.startPosition) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index eb46b46ca5bf4..319de710fbc3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -204,7 +204,7 @@ private[sql] object ResolvedDataSource { provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) - def className = clazz.getCanonicalName + def className: String = clazz.getCanonicalName val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c4da34ae645b8..ae5ce4cf4c7e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -861,7 +861,7 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) - override def newInstance() = { + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 5be09a11ad641..077e64133faad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -659,7 +659,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) = clause match { + def matchSerDe(clause: Seq[ASTNode]) + : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index f73b463d07779..28703ef8129b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -234,7 +234,7 @@ object CheckpointReader extends Logging { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! 
- def fs = checkpointPath.getFileSystem(hadoopConf) + def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 73030e15c5661..808dcc174cf9a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -169,7 +169,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -179,7 +179,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -190,7 +190,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of the RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -201,7 +203,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index f94f2d0e8bd31..93baad19e3ee1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -526,7 +526,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.apply(x).asScala implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] dstream.flatMapValues(fn) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index e3db01c1e12c6..4095a7cc84946 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -192,7 +192,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -313,7 +313,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly) } @@ -344,7 +344,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) } @@ -625,7 +625,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean): Unit = ssc.stop(stopSparkContext) /** * Stop the execution of the streams. 
@@ -633,7 +633,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param stopGracefully Stop gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { ssc.stop(stopSparkContext, stopGracefully) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 795c5aa6d585b..24f99a2b929f5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -839,7 +839,7 @@ object DStream { /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { - def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined val isSparkClass = doesMatch(SPARK_CLASS_REGEX) val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 8d0f09933c8d3..583823c90c5c6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,7 +17,7 @@ package org.apache.spark.tools -import java.lang.reflect.Method +import java.lang.reflect.{Type, Method} import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -302,7 +302,7 @@ object JavaAPICompletenessChecker { private def isExcludedByInterface(method: Method): Boolean = { val excludedInterfaces = Set("org.apache.spark.Logging", "org.apache.hadoop.mapreduce.HadoopMapReduceUtil") - def toComparisionKey(method: Method) = + def toComparisionKey(method: Method): (Class[_], String, Type) = (method.getReturnType, method.getName, method.getGenericReturnType) val interfaces = method.getDeclaringClass.getInterfaces.filter { i => excludedInterfaces.contains(i.getName) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 6b666a0384879..f2d135397ce2f 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * Writes simulated shuffle output from several threads and records the observed throughput. */ object StoragePerfTester { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. 
*/ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) @@ -58,7 +58,7 @@ object StoragePerfTester { val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] - def writeOutputBytes(mapId: Int, total: AtomicLong) = { + def writeOutputBytes(mapId: Int, total: AtomicLong): Unit = { val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers @@ -78,7 +78,7 @@ object StoragePerfTester { val totalBytes = new AtomicLong() for (task <- 1 to numMaps) { executor.submit(new Runnable() { - override def run() = { + override def run(): Unit = { try { writeOutputBytes(task, totalBytes) latch.countDown() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d18690cd9cbf..455554eea0597 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -162,7 +162,7 @@ private[spark] class ApplicationMaster( * status to SUCCEEDED in cluster mode to handle if the user calls System.exit * from the application code. */ - final def getDefaultFinalStatus() = { + final def getDefaultFinalStatus(): FinalApplicationStatus = { if (isClusterMode) { FinalApplicationStatus.SUCCEEDED } else { @@ -175,31 +175,35 @@ private[spark] class ApplicationMaster( * This means the ResourceManager will not retry the application attempt on your behalf if * a failure occurred. */ - final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { - if (!unregistered) { - logInfo(s"Unregistering ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - unregistered = true - client.unregister(status, Option(diagnostics).getOrElse("")) + final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = { + synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) + } } } - final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { - if (!finished) { - val inShutdown = Utils.inShutdown() - logInfo(s"Final app status: ${status}, exitCode: ${code}" + - Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - exitCode = code - finalStatus = status - finalMsg = msg - finished = true - if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { - logDebug("shutting down reporter thread") - reporterThread.interrupt() - } - if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { - logDebug("shutting down user thread") - userClassThread.interrupt() + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { + synchronized { + if (!finished) { + val inShutdown = Utils.inShutdown() + logInfo(s"Final app status: $status, exitCode: $code" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code + finalStatus = status + finalMsg = msg + finished = true + if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { + 
logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() + } } } } @@ -506,7 +510,7 @@ private[spark] class ApplicationMaster( private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor { var driver: ActorSelection = _ - override def preStart() = { + override def preStart(): Unit = { logInfo("Listen to driver: " + driverUrl) driver = context.actorSelection(driverUrl) // Send a hello message to establish the connection, after which @@ -520,7 +524,7 @@ private[spark] class ApplicationMaster( } } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") // In cluster mode, do not rely on the disassociated event to exit @@ -567,7 +571,7 @@ object ApplicationMaster extends Logging { private var master: ApplicationMaster = _ - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => @@ -576,11 +580,11 @@ object ApplicationMaster extends Logging { } } - private[spark] def sparkContextInitialized(sc: SparkContext) = { + private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { master.sparkContextInitialized(sc) } - private[spark] def sparkContextStopped(sc: SparkContext) = { + private[spark] def sparkContextStopped(sc: SparkContext): Boolean = { master.sparkContextStopped(sc) } @@ -592,7 +596,7 @@ object ApplicationMaster extends Logging { */ object ExecutorLauncher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { ApplicationMaster.main(args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index c1d3f7320f53c..1ce10d906ab23 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -59,15 +59,15 @@ class ExecutorRunnable( val yarnConf: YarnConfiguration = new YarnConfiguration(conf) lazy val env = prepareEnvironment(container) - def run = { + override def run(): Unit = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() nmClient.init(yarnConf) nmClient.start() - startContainer + startContainer() } - def startContainer = { + def startContainer(): java.util.Map[String, ByteBuffer] = { logInfo("Setting up ContainerLaunchContext") val ctx = Records.newRecord(classOf[ContainerLaunchContext]) From b0d884f044fea1c954da77073f3556cd9ab1e922 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 3 Apr 2015 09:48:37 +0100 Subject: [PATCH 137/171] [SPARK-6560][CORE] Do not suppress exceptions from writer.write. If there is a failure in the Hadoop backend while calling writer.write, we should remember this original exception, and try to call writer.close(), but if that fails as well, still report the original exception. Note that, if writer.write fails, it is likely that writer was left in an invalid state, and so actually makes it more likely that writer.close will also fail. Which just increases the chances for writer.write's exception to be suppressed. 
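In shorthand, the desired behavior is: run the block, always run the finally block, and if both throw, surface the block's exception rather than the one from the finally block. A condensed, self-contained Scala sketch of that pattern follows; the wrapper object, sample usage, and omission of the warning log are illustrative only, while the real helper is the one added to Utils.scala in the diff below.

    object SafeFinallySketch {
      // Run `block`, always run `finallyBlock`, and if both throw, re-throw the
      // exception from `block` so a close() failure does not mask a write() failure.
      // (The real helper in Utils.scala also logs the suppressed exception.)
      def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
        var original: Throwable = null
        try {
          block
        } catch {
          case t: Throwable =>
            original = t
            throw t
        } finally {
          try {
            finallyBlock
          } catch {
            case t: Throwable =>
              if (original != null) throw original else throw t
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val out = new java.io.ByteArrayOutputStream()
        // Usage mirrors the call sites rewritten below: write in the first block,
        // close in the second.
        tryWithSafeFinally {
          out.write("hello".getBytes("UTF-8"))
        } {
          out.close()
        }
        println(out.toString("UTF-8"))
      }
    }

The diffs below then replace the existing `try { ... } finally { out.close() }` blocks with this helper.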
This patch introduces an admittedly potentially too cute Utils.tryWithSafeFinally method to handle the try/finally gyrations. Author: Stephen Haberman Closes #5223 from stephenh/do_not_suppress_writer_exception and squashes the following commits: c7ad53f [Stephen Haberman] [SPARK-6560][CORE] Do not suppress exceptions from writer.write. --- .../org/apache/spark/MapOutputTracker.scala | 11 +++-- .../apache/spark/api/python/PythonRDD.scala | 8 ++-- .../spark/broadcast/HttpBroadcast.scala | 19 +++++--- .../master/FileSystemPersistenceEngine.scala | 5 +- .../deploy/rest/StandaloneRestClient.scala | 8 +++- .../org/apache/spark/rdd/CheckpointRDD.scala | 8 +++- .../apache/spark/rdd/PairRDDFunctions.scala | 9 ++-- .../shuffle/IndexShuffleBlockManager.scala | 6 +-- .../spark/storage/BlockObjectWriter.scala | 16 ++++--- .../org/apache/spark/storage/DiskStore.scala | 18 ++++---- .../scala/org/apache/spark/util/Utils.scala | 46 +++++++++++++++++-- .../util/collection/ExternalSorter.scala | 26 ++++------- 12 files changed, 118 insertions(+), 62 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c9426c5de23a2..5718951451afc 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -356,11 +356,14 @@ private[spark] object MapOutputTracker extends Logging { def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() } - objOut.close() out.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 36cf2af0857dd..b1ffba4c546bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -614,9 +614,9 @@ private[spark] object PythonRDD extends Logging { try { val sock = serverSocket.accept() val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - try { + Utils.tryWithSafeFinally { writeIteratorToStream(items, out) - } finally { + } { out.close() } } catch { @@ -862,9 +862,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial val file = File.createTempFile("broadcast", "", dir) path = file.getAbsolutePath val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { Utils.copyStream(in, out) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 74ccfa6d3c9a3..4457c75e8b0fc 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -165,7 +165,7 @@ private[broadcast] object HttpBroadcast extends Logging { private def write(id: Long, value: Any) { val file = getFile(id) val fileOutputStream = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { val out: OutputStream = { if (compress) { 
compressionCodec.compressedOutputStream(fileOutputStream) @@ -175,10 +175,13 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() + Utils.tryWithSafeFinally { + serOut.writeObject(value) + } { + serOut.close() + } files += file - } finally { + } { fileOutputStream.close() } } @@ -212,9 +215,11 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj + Utils.tryWithSafeFinally { + serIn.readObject[T]() + } { + serIn.close() + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 32499b3a784a1..f459ed5b3a1a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import akka.serialization.Serialization import org.apache.spark.Logging +import org.apache.spark.util.Utils /** @@ -59,9 +60,9 @@ private[master] class FileSystemPersistenceEngine( val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { out.write(serialized) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 420442f7564cc..a3539e44bd2f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -27,6 +27,7 @@ import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils /** * A client that submits applications to the standalone Master using a REST protocol. 
@@ -148,8 +149,11 @@ private[deploy] class StandaloneRestClient extends Logging { conn.setRequestProperty("charset", "utf-8") conn.setDoOutput(true) val out = new DataOutputStream(conn.getOutputStream) - out.write(json.getBytes(Charsets.UTF_8)) - out.close() + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } readResponse(conn) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 1c13e2c372845..760c0fa3ac96a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -112,8 +113,11 @@ private[spark] object CheckpointRDD extends Logging { } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 6b4f097ea9ae5..bf1303d39592d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -995,7 +995,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L - try { + Utils.tryWithSafeFinally { while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1004,7 +1004,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close(hadoopContext) } committer.commitTask(hadoopContext) @@ -1068,7 +1068,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() var recordsWritten = 0L - try { + + Utils.tryWithSafeFinally { while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1077,7 +1078,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close() } writer.commit() diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 50edb5a34e333..a1741e2875c16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ +import org.apache.spark.util.Utils 
import IndexShuffleBlockManager.NOOP_REDUCE_ID @@ -78,16 +79,15 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockResolver { def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { + Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. var offset = 0L out.writeLong(offset) - for (length <- lengths) { offset += length out.writeLong(offset) } - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index f703e50b6b0ac..0dfc91dfaff85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -23,6 +23,7 @@ import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -140,14 +141,17 @@ private[spark] class DiskBlockObjectWriter( override def close() { if (initialized) { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - callWithTiming { - fos.getFD.sync() + Utils.tryWithSafeFinally { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + callWithTiming { + fos.getFD.sync() + } } + } { + objOut.close() } - objOut.close() channel = null bs = null diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 61ef5ff168791..4b232ae7d3180 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -46,10 +46,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel - while (bytes.remaining > 0) { - channel.write(bytes) + Utils.tryWithSafeFinally { + while (bytes.remaining > 0) { + channel.write(bytes) + } + } { + channel.close() } - channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) @@ -75,9 +78,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) try { - try { + Utils.tryWithSafeFinally { blockManager.dataSerializeStream(blockId, outputStream, values) - } finally { + } { // Close outputStream here because it should be closed before file is deleted. 
outputStream.close() } @@ -106,8 +109,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { val channel = new RandomAccessFile(file, "r").getChannel - - try { + Utils.tryWithSafeFinally { // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) @@ -123,7 +125,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } else { Some(channel.map(MapMode.READ_ONLY, offset, length)) } - } finally { + } { channel.close() } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bb8bd1015668a..7c85e28679f1d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -313,7 +313,7 @@ private[spark] object Utils extends Logging { transferToEnabled: Boolean = false): Long = { var count = 0L - try { + tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. @@ -353,7 +353,7 @@ private[spark] object Utils extends Logging { } } count - } finally { + } { if (closeStreams) { try { in.close() @@ -1214,6 +1214,44 @@ private[spark] object Utils extends Logging { } } + /** + * Execute a block of code, then a finally block, but if exceptions happen in + * the finally block, do not suppress the original exception. + * + * This is primarily an issue with `finally { out.close() }` blocks, where + * close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * fail as well. This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { + // It would be nice to find a method on Try that did this + var originalThrowable: Throwable = null + try { + block + } catch { + case t: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = t + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable => + if (originalThrowable != null) { + // We could do originalThrowable.addSuppressed(t), but it's + // not available in JDK 1.6. + logWarning(s"Suppressing exception in finally: " + t.getMessage, t) + throw originalThrowable + } else { + throw t + } + } + } + } + /** Default filtering function for finding call sites using `getCallSite`. */ private def coreExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the "core" Spark API that we want to skip when @@ -2074,7 +2112,7 @@ private[spark] class RedirectThread( override def run() { scala.util.control.Exception.ignoring(classOf[IOException]) { // FIXME: We copy the stream on the level of bytes to avoid encoding problems. 
- try { + Utils.tryWithSafeFinally { val buf = new Array[Byte](1024) var len = in.read(buf) while (len != -1) { @@ -2082,7 +2120,7 @@ private[spark] class RedirectThread( out.flush() len = in.read(buf) } - } finally { + } { if (propagateEof) { out.close() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7bd3c7852a6b2..035f3767ff554 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -728,25 +728,19 @@ private[spark] class ExternalSorter[K, V, C]( // this simple we spill out the current in-memory collection so that everything is in files. spillToPartitionFiles(if (aggregator.isDefined) map else buffer) partitionWriters.foreach(_.commitAndClose()) - var out: FileOutputStream = null - var in: FileInputStream = null + val out = new FileOutputStream(outputFile, true) val writeStartTime = System.nanoTime - try { - out = new FileOutputStream(outputFile, true) + util.Utils.tryWithSafeFinally { for (i <- 0 until numPartitions) { - in = new FileInputStream(partitionWriters(i).fileSegment().file) - val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - in.close() - in = null - lengths(i) = size - } - } finally { - if (out != null) { - out.close() - } - if (in != null) { - in.close() + val in = new FileInputStream(partitionWriters(i).fileSegment().file) + util.Utils.tryWithSafeFinally { + lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) + } { + in.close() + } } + } { + out.close() context.taskMetrics.shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } From b52c7f9fc87a1b9a039724e1dac8b30554f75196 Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Fri, 3 Apr 2015 10:26:43 +0100 Subject: [PATCH 138/171] [MLLIB] Remove println in LogisticRegression.scala There's no corresponding printing in linear regression. Here was my previous PR (something weird happened and I can't reopen it) https://github.com/apache/spark/pull/5272 Author: Omede Firouz Closes #5338 from oefirouz/println and squashes the following commits: 3f3dbf4 [Omede Firouz] [MLLIB] Remove println --- .../org/apache/spark/ml/classification/LogisticRegression.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21f61d80dd95a..49c00f77480e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -180,7 +180,6 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - println(s"LR.predict with threshold: ${paramMap(threshold)}") if (score(features) > paramMap(threshold)) 1 else 0 } From 512a2f191a6b53699373b6588f316b4437050425 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Fri, 3 Apr 2015 09:49:50 -0700 Subject: [PATCH 139/171] [SPARK-6615][MLLIB] Python API for Word2Vec This is the sub-task of SPARK-6254. Wrap missing method for `Word2Vec` and `Word2VecModel`. 
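Since the Python API here only wraps MLlib's existing Scala implementation, the same workflow on the Scala side looks roughly like the sketch below. It is illustrative only: it assumes a local SparkContext, the object name is made up, and the tiny dataset mirrors the new Python test added in this patch.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.feature.Word2Vec

    object Word2VecSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("Word2VecSketch").setMaster("local[2]"))
        val sentences = sc.parallelize(Seq(
          Seq("a", "b", "c", "d", "e", "f", "g"),
          Seq("a", "b", "c", "d", "e", "f"),
          Seq("a", "b", "c", "d", "e"),
          Seq("a", "b", "c", "d"),
          Seq("a", "b", "c"),
          Seq("a", "b"),
          Seq("a")))
        // setMinCount is the setter newly exposed to Python: tokens seen fewer
        // than minCount times are dropped from the vocabulary.
        val model = new Word2Vec()
          .setVectorSize(10)
          .setSeed(1024L)
          .setMinCount(3)
          .fit(sentences)
        // getVectors, also newly exposed, maps each vocabulary word to its vector.
        model.getVectors.foreach { case (word, vec) =>
          println(s"$word -> [${vec.take(3).mkString(", ")}, ...]")
        }
        sc.stop()
      }
    }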
Author: lewuathe Closes #5296 from Lewuathe/SPARK-6615 and squashes the following commits: f14c304 [lewuathe] Reorder tests 1d326b9 [lewuathe] Merge master e2bedfb [lewuathe] Modify test cases afb866d [lewuathe] [SPARK-6615] Python API for Word2Vec --- .../mllib/api/python/PythonMLLibAPI.scala | 8 +++- python/pyspark/mllib/feature.py | 18 +++++++- python/pyspark/mllib/tests.py | 45 ++++++++++++++++--- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 5995d6df97c15..6c386cacfb7ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -476,13 +476,15 @@ private[python] class PythonMLLibAPI extends Serializable { learningRate: Double, numPartitions: Int, numIterations: Int, - seed: Long): Word2VecModelWrapper = { + seed: Long, + minCount: Int): Word2VecModelWrapper = { val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) + .setMinCount(minCount) try { val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) new Word2VecModelWrapper(model) @@ -516,6 +518,10 @@ private[python] class PythonMLLibAPI extends Serializable { val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } } /** diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 4bfe3014ef748..3cda1205e1391 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -337,6 +337,12 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + def getVectors(self): + """ + Returns a map of words to their vector representations. + """ + return self.call("getVectors") + class Word2Vec(object): """ @@ -379,6 +385,7 @@ def __init__(self): self.numPartitions = 1 self.numIterations = 1 self.seed = random.randint(0, sys.maxint) + self.minCount = 5 def setVectorSize(self, vectorSize): """ @@ -417,6 +424,14 @@ def setSeed(self, seed): self.seed = seed return self + def setMinCount(self, minCount): + """ + Sets minCount, the minimum number of times a token must appear + to be included in the word2vec model's vocabulary (default: 5). + """ + self.minCount = minCount + return self + def fit(self, data): """ Computes the vector representation of each word in vocabulary. 
@@ -428,7 +443,8 @@ def fit(self, data): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) + int(self.numIterations), long(self.seed), + int(self.minCount)) return Word2VecModel(jmodel) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 6e9c68ec8a5c1..dd3b66ce67457 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -42,6 +42,7 @@ from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext @@ -630,6 +631,12 @@ def test_right_number_of_results(self): self.assertIsNotNone(chi[1000]) +class SerDeTest(PySparkTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + class FeatureTest(PySparkTestCase): def test_idf_model(self): data = [ @@ -643,11 +650,39 @@ def test_idf_model(self): self.assertEqual(len(idf), 11) -class SerDeTest(PySparkTestCase): - def test_to_java_object_rdd(self): # SPARK-6660 - data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) - self.assertEqual(_to_java_object_rdd(data).count(), 10) - +class Word2VecTests(PySparkTestCase): + def test_word2vec_setters(self): + data = [ + ["I", "have", "a", "pen"], + ["I", "like", "soccer", "very", "much"], + ["I", "live", "in", "Tokyo"] + ] + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) + self.assertEquals(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEquals(model.numPartitions, 2) + self.assertEquals(model.numIterations, 10) + self.assertEquals(model.seed, 1024) + self.assertEquals(model.minCount, 3) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEquals(len(model.getVectors()), 3) if __name__ == "__main__": if not _have_scipy: From dc6dff248d8f5d7de22af64b0586dfe3885731df Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 3 Apr 2015 18:31:48 +0100 Subject: [PATCH 140/171] [Minor][SQL] Fix typo Just fix a typo. Author: Liang-Chi Hsieh Closes #5352 from viirya/fix_a_typo and squashes the following commits: 303b2d2 [Liang-Chi Hsieh] Fix typo. 
--- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ae5ce4cf4c7e7..315fab673da5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -271,7 +271,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } case other => logWarning( - s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} shold be stored " + + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + s"as Parquet. However, we are getting a ${other} from the metastore cache. " + s"This cached entry will be invalidated.") cachedDataSourceTables.invalidate(tableIdentifier) From c23ba81b8cf86c3a085de8ddfef9403ff6fcd87f Mon Sep 17 00:00:00 2001 From: guowei2 Date: Sat, 4 Apr 2015 02:02:30 +0800 Subject: [PATCH 141/171] [SPARK-5203][SQL] fix union with different decimal type When union non-decimal types with decimals, we use the following rules: - FIRST `intTypeToFixed`, then fixed union decimals with precision/scale p1/s2 and p2/s2 will be promoted to DecimalType(max(p1, p2), max(s1, s2)) - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, but note that unlimited decimals are considered bigger than doubles in WidenTypes) Author: guowei2 Closes #4004 from guowei2/SPARK-5203 and squashes the following commits: ff50f5f [guowei2] fix code style 11df1bf [guowei2] fix decimal union with double, double->Decimal(15,15) 0f345f9 [guowei2] fix structType merge with decimal 101ed4d [guowei2] fix build error after rebase 0b196e4 [guowei2] code style fe2c2ca [guowei2] handle union decimal precision in 'DecimalPrecision' 421d840 [guowei2] fix union types for decimal precision ef2c661 [guowei2] fix union with different decimal type --- .../catalyst/analysis/HiveTypeCoercion.scala | 190 ++++++++++++------ .../apache/spark/sql/types/dataTypes.scala | 5 +- .../analysis/DecimalPrecisionSuite.scala | 30 ++- .../sql/hive/execution/SQLQuerySuite.scala | 11 + 4 files changed, 167 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 9a33eb145273e..3aeb964994d37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -285,6 +285,7 @@ trait HiveTypeCoercion { * Calculates and propagates precision for fixed-precision decimals. 
Hive has a number of * rules for this based on the SQL standard and MS SQL: * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx * * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 * respectively, then the following operations have the following precision / scale: @@ -296,6 +297,7 @@ trait HiveTypeCoercion { * e1 * e2 p1 + p2 + 1 s1 + s2 * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 * @@ -311,7 +313,12 @@ trait HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * - FLOAT and DOUBLE + * 1. Union operation: + * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the + * same as Hive) + * 2. Other operation: + * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, * but note that unlimited decimals are considered bigger than doubles in WidenTypes) */ // scalastyle:on @@ -328,76 +335,127 @@ trait HiveTypeCoercion { def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e + // Conversion rules for float and double into fixed-precision decimals + val floatTypeToFixed: Map[DataType, DecimalType] = Map( + FloatType -> DecimalType(7, 7), + DoubleType -> DecimalType(15, 15) + ) - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) - - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 
@ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) - case _ => - b + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + case (l, r) if l.dataType != r.dataType => + (l.dataType, r.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => + (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) + case _ => (l, r) + } + case other => other } - // TODO: MaxOf, MinOf, etc might want other rules + val (castedLeft, castedRight) = castedInput.unzip - // SUM and AVERAGE are handled by the implementations of those expressions + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) + + // fix decimal precision for expressions + case q => q.transformExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ 
DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 952cf5c75688d..cdf2bc68d9c5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import scala.math._ import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -934,7 +935,9 @@ object StructType { case (DecimalType.Fixed(leftPrecision, leftScale), DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) 
if leftUdt.userClass == rightUdt.userClass => leftUdt diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index bc2ec754d5865..67bec999dfbd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), AttributeReference("u", DecimalType.Unlimited)(), - AttributeReference("f", FloatType)() + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() ) val i: Expression = UnresolvedAttribute("i") @@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val d2: Expression = UnresolvedAttribute("d2") val u: Expression = UnresolvedAttribute("u") val f: Expression = UnresolvedAttribute("f") + val b: Expression = UnresolvedAttribute("b") before { catalog.registerTable(Seq("table"), relation) @@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { assert(comparison.right.dataType === expectedType) } + private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = { + val plan = + Union(Project(Seq(Alias(left, "l")()), relation), + Project(Seq(Alias(right, "r")()), relation)) + val (l, r) = analyzer(plan).collect { + case Union(left, right) => (left.output.head, right.output.head) + }.head + assert(l.dataType === expectedType) + assert(r.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } + test("decimal precision for union") { + checkUnion(d1, i, DecimalType(11, 1)) + checkUnion(i, d2, DecimalType(12, 2)) + checkUnion(d1, d2, DecimalType(5, 2)) + checkUnion(d2, d1, DecimalType(5, 2)) + checkUnion(d1, f, DecimalType(8, 7)) + checkUnion(f, d2, DecimalType(10, 7)) + checkUnion(d1, b, DecimalType(16, 15)) + checkUnion(b, d2, DecimalType(18, 15)) + checkUnion(d1, u, DecimalType.Unlimited) + checkUnion(u, d2, DecimalType.Unlimited) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2065f0d60d92f..817b9dcb8f505 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -468,4 +468,15 @@ class SQLQuerySuite extends QueryTest { sql(s"DROP TABLE $tableName") } } + + test("SPARK-5203 union with different decimal precision") { + Seq.empty[(Decimal, Decimal)] + .toDF("d1", "d2") 
+ .select($"d1".cast(DecimalType(10, 15)).as("d")) + .registerTempTable("dn") + + sql("select d from dn union all select d * 2 from dn") + .queryExecution.analyzed + } + } From 2c43ea38ee0db6b292c14baf6bc6f8d16f509c9d Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Fri, 3 Apr 2015 19:23:11 +0100 Subject: [PATCH 142/171] [SPARK-6492][CORE] SparkContext.stop() can deadlock when DAGSchedulerEventProcessLoop dies I've added a timeout and retry loop around the SparkContext shutdown code that should fix this deadlock. If a SparkContext shutdown is in progress when another thread comes knocking, it will wait for 10 seconds for the lock, then fall through where the outer loop will re-submit the request. Author: Ilya Ganelin Closes #5277 from ilganeli/SPARK-6492 and squashes the following commits: 8617a7e [Ilya Ganelin] Resolved merge conflict 2fbab66 [Ilya Ganelin] Added MIMA Exclude a0e2c70 [Ilya Ganelin] Deleted stale imports fa28ce7 [Ilya Ganelin] reverted to just having a single stopped 76fc825 [Ilya Ganelin] Updated to use atomic booleans instead of the synchronized vars 6e8a7f7 [Ilya Ganelin] Removing unecessary null check for now since i'm not fixing stop ordering yet cdf7073 [Ilya Ganelin] [SPARK-6492] Moved stopped=true back to the start of the shutdown sequence so this can be addressed in a seperate PR 7fb795b [Ilya Ganelin] Spacing b7a0c5c [Ilya Ganelin] Import ordering df8224f [Ilya Ganelin] Added comment for added lock 343cb94 [Ilya Ganelin] [SPARK-6492] Added timeout/retry logic to fix a deadlock in SparkContext shutdown --- .../scala/org/apache/spark/SparkContext.scala | 59 ++++++++++--------- project/MimaExcludes.scala | 4 ++ 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5b3778ead6994..abf81e312d8e6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -23,7 +23,7 @@ import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID import scala.collection.{Map, Set} @@ -95,10 +95,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - @volatile private var stopped: Boolean = false + private val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("Cannot call methods on a stopped SparkContext") } } @@ -1390,33 +1390,34 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli addedJars.clear() } - /** Shut down the SparkContext. */ + // Shut down the SparkContext. def stop() { - SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - if (!stopped) { - stopped = true - postApplicationEnd() - ui.foreach(_.stop()) - env.metricsSystem.report() - metadataCleaner.cancel() - cleaner.foreach(_.stop()) - executorAllocationManager.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - listenerBus.stop() - eventLogger.foreach(_.stop()) - env.actorSystem.stop(heartbeatReceiver) - progressBar.foreach(_.stop()) - taskScheduler = null - // TODO: Cache.stop()? 
- env.stop() - SparkEnv.set(null) - logInfo("Successfully stopped SparkContext") - SparkContext.clearActiveContext() - } else { - logInfo("SparkContext already stopped") - } + // Use the stopping variable to ensure no contention for the stop scenario. + // Still track the stopped variable for use elsewhere in the code. + + if (!stopped.compareAndSet(false, true)) { + logInfo("SparkContext already stopped.") + return } + + postApplicationEnd() + ui.foreach(_.stop()) + env.metricsSystem.report() + metadataCleaner.cancel() + cleaner.foreach(_.stop()) + executorAllocationManager.foreach(_.stop()) + dagScheduler.stop() + dagScheduler = null + listenerBus.stop() + eventLogger.foreach(_.stop()) + env.actorSystem.stop(heartbeatReceiver) + progressBar.foreach(_.stop()) + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + SparkEnv.set(null) + SparkContext.clearActiveContext() + logInfo("Successfully stopped SparkContext") } @@ -1478,7 +1479,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 54500f7c2701f..c2d828f982fe0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -60,6 +60,10 @@ object MimaExcludes { ) ++ Seq( // SPARK-6510 Add a Graph#minus method acting as Set#difference ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") ) case v if v.startsWith("1.3") => From 88504b75ee610e14d7ceed8b038fa698a3d14f81 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 3 Apr 2015 11:44:27 -0700 Subject: [PATCH 143/171] [SPARK-6640][Core] Fix the race condition of creating HeartbeatReceiver and retrieving HeartbeatReceiver This PR moved the code of creating `HeartbeatReceiver` above the code of creating `schedulerBackend` to resolve the race condition. Author: zsxwing Closes #5306 from zsxwing/SPARK-6640 and squashes the following commits: 840399d [zsxwing] Don't send TaskScheduler through Akka a90616a [zsxwing] Fix docs dd202c7 [zsxwing] Fix typo d7c250d [zsxwing] Fix the race condition of creating HeartbeatReceiver and retrieving HeartbeatReceiver --- .../org/apache/spark/HeartbeatReceiver.scala | 32 +++++++++++++++---- .../scala/org/apache/spark/SparkContext.scala | 10 ++++-- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 8435e1ea2611c..9f8ad03b91e85 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -37,6 +37,12 @@ private[spark] case class Heartbeat( taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId) +/** + * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is + * created. 
+ */ +private[spark] case object TaskSchedulerIsSet + private[spark] case object ExpireDeadHosts private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) @@ -44,9 +50,11 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskScheduler) +private[spark] class HeartbeatReceiver(sc: SparkContext) extends Actor with ActorLogReceive with Logging { + private var scheduler: TaskScheduler = null + // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new mutable.HashMap[String, Long] @@ -71,12 +79,22 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule } override def receiveWithLogging: PartialFunction[Any, Unit] = { - case Heartbeat(executorId, taskMetrics, blockManagerId) => - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - executorLastSeen(executorId) = System.currentTimeMillis() - sender ! response + case TaskSchedulerIsSet => + scheduler = sc.taskScheduler + case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + if (scheduler != null) { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + executorLastSeen(executorId) = System.currentTimeMillis() + sender ! response + } else { + // Because Executor will sleep several seconds before sending the first "Heartbeat", this + // case rarely happens. However, if it really happens, log it and ask the executor to + // register itself again. + logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") + sender ! HeartbeatResponse(reregisterBlockManager = true) + } case ExpireDeadHosts => expireDeadHosts() } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index abf81e312d8e6..fd1838976ee22 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -356,11 +356,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val sparkUser = Utils.getCurrentUserName() executorEnvs("SPARK_USER") = sparkUser + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) + private val heartbeatReceiver = env.actorSystem.actorOf( + Props(new HeartbeatReceiver(this)), "HeartbeatReceiver") + // Create and start the scheduler private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(this, taskScheduler)), "HeartbeatReceiver") + + heartbeatReceiver ! 
TaskSchedulerIsSet + @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) From ffe8cc9a25454ee4f451f6ee3ec6d1e934b47ca2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 3 Apr 2015 11:53:07 -0700 Subject: [PATCH 144/171] Closes #3158 From 14632b7942c02a332c4d3814fb6b2611e3f76fc7 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 3 Apr 2015 11:54:31 -0700 Subject: [PATCH 145/171] [SPARK-6688] [core] Always use resolved URIs in EventLoggingListener. Author: Marcelo Vanzin Closes #5340 from vanzin/SPARK-6688 and squashes the following commits: ccfddd9 [Marcelo Vanzin] Resolve at the source. 20d2a34 [Marcelo Vanzin] [SPARK-6688] [core] Always use resolved URIs in EventLoggingListener. --- .../scala/org/apache/spark/SparkContext.scala | 6 +++-- .../spark/deploy/ApplicationDescription.scala | 6 +++-- .../scheduler/EventLoggingListener.scala | 10 ++++----- .../history/FsHistoryProviderSuite.scala | 2 +- .../scheduler/EventLoggingListenerSuite.scala | 22 ++++++++++++------- .../spark/scheduler/ReplayListenerSuite.scala | 3 ++- 6 files changed, 30 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index fd1838976ee22..3b73a8a8fd850 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -227,9 +227,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val appName = conf.get("spark.app.name") private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[String] = { + private[spark] val eventLogDir: Option[URI] = { if (isEventLogEnabled) { - Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 3d0d68de8f495..b7ae9c1fc0a23 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,13 +17,15 @@ package org.apache.spark.deploy +import java.net.URI + private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], val memoryPerSlave: Int, val command: Command, var appUiUrl: String, - val eventLogDir: Option[String] = None, + val eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. 
lzf) val eventLogCodec: Option[String] = None) extends Serializable { @@ -36,7 +38,7 @@ private[spark] class ApplicationDescription( memoryPerSlave: Int = memoryPerSlave, command: Command = command, appUiUrl: String = appUiUrl, - eventLogDir: Option[String] = eventLogDir, + eventLogDir: Option[URI] = eventLogDir, eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = new ApplicationDescription( name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir, eventLogCodec) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index c0d889360ae99..08e7727db2fde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -47,21 +47,21 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ private[spark] class EventLoggingListener( appId: String, - logBaseDir: String, + logBaseDir: URI, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + def this(appId: String, logBaseDir: URI, sparkConf: SparkConf) = this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf) + private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { Some(CompressionCodec.createCodec(sparkConf)) @@ -259,13 +259,13 @@ private[spark] object EventLoggingListener extends Logging { * @return A path which consists of file-system-safe characters. */ def getLogPath( - logBaseDir: String, + logBaseDir: URI, appId: String, compressionCodecName: Option[String] = None): String = { val sanitizedAppId = appId.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase // e.g. app_123, app_123.lzf val logName = sanitizedAppId + compressionCodecName.map { "." 
+ _ }.getOrElse("") - Utils.resolveURI(logBaseDir).toString.stripSuffix("/") + "/" + logName + logBaseDir.toString.stripSuffix("/") + "/" + logName } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index e908ba604ebed..fcae603c7d18e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -50,7 +50,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers inProgress: Boolean, codec: Option[String] = None): File = { val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, appId) + val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId) val logPath = new URI(logUri).getPath + ip new File(logPath) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 448258a754153..30ee63e78d9d8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -61,7 +61,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Verify log file exist") { // Verify logging directory exists val conf = getLoggingConf(testDirPath) - val eventLogger = new EventLoggingListener("test", testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener("test", testDirPath.toUri(), conf) eventLogger.start() val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) @@ -95,7 +95,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef } test("Log overwriting") { - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, "test") + val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test") val logPath = new URI(logUri).getPath // Create file before writing the event log new FileOutputStream(new File(logPath)).close() @@ -107,16 +107,19 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Event log name") { // without compression - assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath("/base-dir", "app1")) + assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( + Utils.resolveURI("/base-dir"), "app1")) // with compression assert(s"file:/base-dir/app1.lzf" === - EventLoggingListener.getLogPath("/base-dir", "app1", Some("lzf"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", Some("lzf"))) // illegal characters in app ID assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1")) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1")) // illegal characters in app ID with compression assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1", Some("lz4"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1", Some("lz4"))) } /* ----------------- * @@ -137,7 +140,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef val conf = 
getLoggingConf(testDirPath, compressionCodec) extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") - val eventLogger = new EventLoggingListener(logName, testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener(logName, testDirPath.toUri(), conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -173,12 +176,15 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef * This runs a simple Spark job and asserts that the expected events are logged when expected. */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { + // Set defaultFS to something that would cause an exception, to make sure we don't run + // into SPARK-6688. val conf = getLoggingConf(testDirPath, compressionCodec) + .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath - val expectedLogDir = testDir.toURI().toString() + val expectedLogDir = testDir.toURI() assert(eventLogPath === EventLoggingListener.getLogPath( expectedLogDir, sc.applicationId, compressionCodec.map(CompressionCodec.getShortName))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 601694f57aad0..6de6d2fec622a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} +import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -145,7 +146,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * log the events. */ private class EventMonster(conf: SparkConf) - extends EventLoggingListener("test", "testdir", conf) { + extends EventLoggingListener("test", new URI("testdir"), conf) { override def start() { } From 26b415e15970d02523f0df459557b09ffda0c8c1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 3 Apr 2015 12:35:00 -0700 Subject: [PATCH 146/171] [SPARK-6647][SQL] Make trait StringComparison as BinaryPredicate and fix unit tests of string data source Filter Now trait `StringComparison` is a `BinaryExpression`. In fact, it should be a `BinaryPredicate`. By making `StringComparison` as `BinaryPredicate`, we can throw error when a `expressions.Predicate` can't translate to a data source `Filter` in function `selectFilters`. Without this modification, because we will wrap a `Filter` outside the scanned results in `pruneFilterProjectRaw`, we can't detect about something is wrong in translating predicates to filters in `selectFilters`. The unit test of #5285 demonstrates such problem. In that pr, even `expressions.Contains` is not properly translated to `sources.StringContains`, the filtering is still performed by the `Filter` and so the test passes. Of course, by doing this modification, all `expressions.Predicate` classes need to have its data source `Filter` correspondingly. There is a small bug in `FilteredScanSuite` for doing `StringEndsWith` filter. This pr also fixes it. 
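
The practical effect of re-parenting StringComparison under BinaryPredicate is that LIKE 'c%'-style predicates are now translated into org.apache.spark.sql.sources filters and handed to the relation, instead of silently falling back to a post-scan Filter. Below is a minimal sketch of what a PrunedFilteredScan implementation sees for column "c"; the helper name is made up, and only the StringStartsWith / StringEndsWith / StringContains classes come from the sources API.

import org.apache.spark.sql.sources._

// Hypothetical helper inside a data source relation: evaluate one pushed-down
// string filter on column "c" against a candidate value.
def evalStringFilter(filter: Filter, value: String): Boolean = filter match {
  case StringStartsWith("c", prefix) => value.startsWith(prefix)   // c LIKE 'prefix%'
  case StringEndsWith("c", suffix)   => value.endsWith(suffix)     // c LIKE '%suffix'
  case StringContains("c", sub)      => value.contains(sub)        // c LIKE '%sub%'
  case _                             => true  // anything else is re-checked by Spark after the scan
}
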
Author: Liang-Chi Hsieh Closes #5309 from viirya/translate_predicate and squashes the following commits: b176385 [Liang-Chi Hsieh] Address comment. 275a493 [Liang-Chi Hsieh] More properly test for StringStartsWith, StringEndsWith and StringContains. caf2347 [Liang-Chi Hsieh] Make trait StringComparison as BinaryPredicate and throw error when Predicate can't translate to data source Filter. --- .../expressions/stringOperations.scala | 11 ++++---- .../spark/sql/sources/FilteredScanSuite.scala | 28 +++++++++++++------ 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 3cdca4e9dd2d1..acfbbace608ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -156,12 +156,11 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE /** A base trait for functions that compare two strings, returning a boolean. */ trait StringComparison { - self: BinaryExpression => + self: BinaryPredicate => - type EvaluatedType = Any + override type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -184,7 +183,7 @@ trait StringComparison { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.contains(r) } @@ -192,7 +191,7 @@ case class Contains(left: Expression, right: Expression) * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.startsWith(r) } @@ -200,7 +199,7 @@ case class StartsWith(left: Expression, right: Expression) * A function that returns true if the string `left` ends with the string `right`. 
*/ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.endsWith(r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 72ddc0ea2c8cb..773bd1602d5e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -45,7 +45,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) - case "c" => (i: Int) => Seq((i - 1 + 'a').toChar.toString * 10) + case "c" => (i: Int) => + val c = (i - 1 + 'a').toChar.toString + Seq(c * 5 + c.toUpperCase() * 5) } FiltersPushed.list = filters @@ -77,7 +79,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 10 + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 !filters.map(translateFilterOnA(_)(a)).contains(false) && !filters.map(translateFilterOnC(_)(c)).contains(false) } @@ -110,7 +112,8 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 10)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 + + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -182,15 +185,15 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", - Seq(Row(3, 3 * 2, "c" * 10))) + Seq(Row(3, 3 * 2, "c" * 5 + "C" * 5))) sqlTest( - "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'd%'", - Seq(Row(4, 4 * 2, "d" * 10))) + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", + Seq(Row(4, 4 * 2, "d" * 5 + "D" * 5))) sqlTest( - "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%e%'", - Seq(Row(5, 5 * 2, "e" * 10))) + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", + Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) @@ -222,6 +225,15 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution From 9b40c17ab161b64933539abeefde443cb4f98673 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 3 Apr 2015 15:22:21 -0700 Subject: [PATCH 147/171] [SPARK-6700] disable flaky test Author: Davies Liu Closes 
#5356 from davies/flaky and squashes the following commits: 08955f4 [Davies Liu] disable flaky test --- .../scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 0e37276ba724b..c06c0105670c0 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -143,7 +143,8 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } } - test("run Python application in yarn-cluster mode") { + // Enable this once fix SPARK-6700 + ignore("run Python application in yarn-cluster mode") { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) val pyFile = new File(tempDir, "test2.py") From da25c86d64ff9ce80f88186ba083f6c21dd9a568 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 4 Apr 2015 23:26:10 +0800 Subject: [PATCH 148/171] [SQL] Use path.makeQualified in newParquet. Author: Yin Huai Closes #5353 from yhuai/wrongFS and squashes the following commits: 849603b [Yin Huai] Not use deprecated method. 6d6ae34 [Yin Huai] Use path.makeQualified. --- .../main/scala/org/apache/spark/sql/parquet/newParquet.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 583bac42fdcce..0dce3623a66df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -268,7 +268,8 @@ private[sql] case class ParquetRelation2( // containing Parquet files (e.g. partitioned Parquet table). val baseStatuses = paths.distinct.map { p => val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) - val qualified = fs.makeQualified(new Path(p)) + val path = new Path(p) + val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) if (!fs.exists(qualified) && maybeSchema.isDefined) { fs.mkdirs(qualified) From 7bca62f79056e592cf07b49d8b8d04c59dea25fc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 5 Apr 2015 00:20:43 +0800 Subject: [PATCH 149/171] [SPARK-6607][SQL] Check invalid characters for Parquet schema and show error messages '(' and ')' are special characters used in Parquet schema for type annotation. When we run an aggregation query, we will obtain attribute name such as "MAX(a)". If we directly store the generated DataFrame as Parquet file, it causes failure when reading and parsing the stored schema string. Several methods can be adopted to solve this. This pr uses a simplest one to just replace attribute names before generating Parquet schema based on these attributes. Another possible method might be modifying all aggregation expression names from "func(column)" to "func[column]". Author: Liang-Chi Hsieh Closes #5263 from viirya/parquet_aggregation_name and squashes the following commits: 2d70542 [Liang-Chi Hsieh] Address comment. 463dff4 [Liang-Chi Hsieh] Instead of replacing special chars, showing error message to user to suggest using Alias. 1de001d [Liang-Chi Hsieh] Replace special characters '(' and ')' of Parquet schema. 
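
The takeaway for users, before the diff below: alias aggregate columns before writing Parquet, because groupBy(...).max(...) produces names containing parentheses (e.g. a "MAX(int)"-style name) that now trip the schema check. A short Scala sketch of the workaround, following the 1.3-era API used in the added test; sqlCtx is assumed to be an existing SQLContext/HiveContext with its implicits in scope, and the path is illustrative.

import sqlCtx.implicits._

val df = Seq((1, "1"), (2, "2"), (3, "3")).toDF("int", "str")

// The aggregate column name contains parentheses, so saving it as Parquet
// now fails fast with a message suggesting an alias.
val agg = df.groupBy("str").max("int")

// Workaround: rename the offending column first, then write.
val renamed = agg.toDF("str", "max_int")
renamed.saveAsParquetFile("/tmp/max_by_str.parquet")
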
--- .../apache/spark/sql/parquet/ParquetTypes.scala | 14 ++++++++++++++ .../apache/spark/sql/hive/parquetSuites.scala | 16 ++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index da668f068613b..60e1bec4db8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -390,6 +390,7 @@ private[parquet] object ParquetTypesConverter extends Logging { def convertFromAttributes(attributes: Seq[Attribute], toThriftSchemaNames: Boolean = false): MessageType = { + checkSpecialCharacters(attributes) val fields = attributes.map( attribute => fromDataType(attribute.dataType, attribute.name, attribute.nullable, @@ -404,7 +405,20 @@ private[parquet] object ParquetTypesConverter extends Logging { } } + private def checkSpecialCharacters(schema: Seq[Attribute]) = { + // ,;{}()\n\t= and space character are special characters in Parquet schema + schema.map(_.name).foreach { name => + if (name.matches(".*[ ,;{}()\n\t=].*")) { + sys.error( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + } + } + def convertToString(schema: Seq[Attribute]): String = { + checkSpecialCharacters(schema) StructType.fromAttributes(schema).json } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 1319c81dfc131..5f71e1bbc2d2e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -688,6 +688,22 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { sql("DROP TABLE alwaysNullable") } + + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + val tempDir = Utils.createTempDir() + val filePath = new File(tempDir, "testParquet").getCanonicalPath + val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath + + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") + intercept[RuntimeException](df2.saveAsParquetFile(filePath)) + + val df3 = df2.toDF("str", "max_int") + df3.saveAsParquetFile(filePath2) + val df4 = parquetFile(filePath2) + checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) + assert(df4.columns === Array("str", "max_int")) + } } class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { From f15806a8f8ca34288ddb2d74b9ff1972c8374b59 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 4 Apr 2015 11:52:05 -0700 Subject: [PATCH 150/171] [SPARK-6602][Core] Replace direct use of Akka with Spark RPC interface - part 1 This PR replaced the following `Actor`s to `RpcEndpoint`: 1. HeartbeatReceiver 1. ExecutorActor 1. BlockManagerMasterActor 1. BlockManagerSlaveActor 1. CoarseGrainedExecutorBackend and subclasses 1. CoarseGrainedSchedulerBackend.DriverActor This is the first PR. I will split the work of SPARK-6602 to several PRs for code review. 
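
To make the shape of the new abstraction concrete before the diff: an endpoint overrides receive for fire-and-forget messages and receiveAndReply for ask-style messages, answering through an RpcCallContext instead of Akka's "sender !". A minimal sketch in the style of the reworked HeartbeatReceiver follows; the message types, endpoint name and rpcEnv value are placeholders, and since the rpc package is private[spark] this only illustrates the shape of the API introduced by this series.

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

// Hypothetical messages, for illustration only.
case object Tick
case class Ping(id: Int)

class PingEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

  // Messages delivered with endpointRef.send(...) arrive here; no reply expected.
  override def receive: PartialFunction[Any, Unit] = {
    case Tick => // periodic housekeeping
  }

  // Messages delivered with endpointRef.askWithReply[...](...) arrive here and are
  // answered through the context rather than an implicit sender.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping(id) => context.reply(s"pong-$id")
  }
}

// Registration replaces actorSystem.actorOf(Props(...), name):
//   val ref = rpcEnv.setupEndpoint("ping", new PingEndpoint(rpcEnv))
//   val answer = ref.askWithReply[String](Ping(7))
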
Author: zsxwing Closes #5268 from zsxwing/rpc-rewrite and squashes the following commits: 287e9f8 [zsxwing] Fix the code style 26c56b7 [zsxwing] Merge branch 'master' into rpc-rewrite 9cc825a [zsxwing] Rmove setupThreadSafeEndpoint and add ThreadSafeRpcEndpoint 30a9036 [zsxwing] Make self return null after stopping RpcEndpointRef; fix docs and error messages 705245d [zsxwing] Fix some bugs after rebasing the changes on the master 003cf80 [zsxwing] Update CoarseGrainedExecutorBackend and CoarseGrainedSchedulerBackend to use RpcEndpoint 7d0e6dc [zsxwing] Update BlockManagerSlaveActor to use RpcEndpoint f5d6543 [zsxwing] Update BlockManagerMaster to use RpcEndpoint 30e3f9f [zsxwing] Update ExecutorActor to use RpcEndpoint 478b443 [zsxwing] Update HeartbeatReceiver to use RpcEndpoint --- .../org/apache/spark/HeartbeatReceiver.scala | 66 +++++--- .../scala/org/apache/spark/SparkContext.scala | 23 +-- .../scala/org/apache/spark/SparkEnv.scala | 13 +- .../CoarseGrainedExecutorBackend.scala | 79 +++++----- .../org/apache/spark/executor/Executor.scala | 18 +-- ...utorActor.scala => ExecutorEndpoint.scala} | 18 ++- .../scala/org/apache/spark/rpc/RpcEnv.scala | 39 +++-- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 10 +- .../apache/spark/scheduler/DAGScheduler.scala | 11 +- .../cluster/CoarseGrainedClusterMessage.scala | 6 +- .../CoarseGrainedSchedulerBackend.scala | 148 +++++++++--------- .../scheduler/cluster/ExecutorData.scala | 8 +- .../cluster/SimrSchedulerBackend.scala | 13 +- .../cluster/SparkDeploySchedulerBackend.scala | 14 +- .../cluster/YarnSchedulerBackend.scala | 93 +++++------ .../mesos/CoarseMesosSchedulerBackend.scala | 4 +- .../spark/scheduler/local/LocalBackend.scala | 48 +++--- .../apache/spark/storage/BlockManager.scala | 22 +-- .../spark/storage/BlockManagerMaster.scala | 72 ++++----- ...scala => BlockManagerMasterEndpoint.scala} | 119 +++++++------- .../spark/storage/BlockManagerMessages.scala | 7 +- ....scala => BlockManagerSlaveEndpoint.scala} | 44 +++--- .../scala/org/apache/spark/util/Utils.scala | 10 ++ .../apache/spark/HeartbeatReceiverSuite.scala | 81 ++++++++++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 14 +- .../BlockManagerReplicationSuite.scala | 28 ++-- .../spark/storage/BlockManagerSuite.scala | 37 ++--- .../streaming/ReceivedBlockHandlerSuite.scala | 25 ++- .../spark/deploy/yarn/ApplicationMaster.scala | 86 +++++----- .../spark/deploy/yarn/YarnAllocator.scala | 2 +- 30 files changed, 616 insertions(+), 542 deletions(-) rename core/src/main/scala/org/apache/spark/executor/{ExecutorActor.scala => ExecutorEndpoint.scala} (67%) rename core/src/main/scala/org/apache/spark/storage/{BlockManagerMasterActor.scala => BlockManagerMasterEndpoint.scala} (83%) rename core/src/main/scala/org/apache/spark/storage/{BlockManagerSlaveActor.scala => BlockManagerSlaveEndpoint.scala} (61%) create mode 100644 core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 9f8ad03b91e85..5871b8c869f03 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,15 +17,15 @@ package org.apache.spark -import scala.concurrent.duration._ -import scala.collection.mutable +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} -import akka.actor.{Actor, Cancellable} +import scala.collection.mutable import org.apache.spark.executor.TaskMetrics +import 
org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -51,9 +51,11 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) * Lives in the driver to receive heartbeats from executors.. */ private[spark] class HeartbeatReceiver(sc: SparkContext) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { + + override val rpcEnv: RpcEnv = sc.env.rpcEnv - private var scheduler: TaskScheduler = null + private[spark] var scheduler: TaskScheduler = null // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new mutable.HashMap[String, Long] @@ -69,34 +71,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000). getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)) - private var timeoutCheckingTask: Cancellable = null - - override def preStart(): Unit = { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts) - super.preStart() + private var timeoutCheckingTask: ScheduledFuture[_] = null + + private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("heartbeat-timeout-checking-thread")) + + private val killExecutorThread = Executors.newSingleThreadExecutor( + Utils.namedThreadFactory("kill-executor-thread")) + + override def onStart(): Unit = { + timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ExpireDeadHosts)) + } + }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - - override def receiveWithLogging: PartialFunction[Any, Unit] = { + + override def receive: PartialFunction[Any, Unit] = { + case ExpireDeadHosts => + expireDeadHosts() case TaskSchedulerIsSet => scheduler = sc.taskScheduler + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { val unknownExecutor = !scheduler.executorHeartbeatReceived( executorId, taskMetrics, blockManagerId) val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - sender ! response + context.reply(response) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to // register itself again. logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") - sender ! 
HeartbeatResponse(reregisterBlockManager = true) + context.reply(HeartbeatResponse(reregisterBlockManager = true)) } - case ExpireDeadHosts => - expireDeadHosts() } private def expireDeadHosts(): Unit = { @@ -109,17 +121,25 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + s"timed out after ${now - lastSeenMs} ms")) if (sc.supportDynamicAllocation) { - sc.killExecutor(executorId) + // Asynchronously kill the executor to avoid blocking the current thread + killExecutorThread.submit(new Runnable { + override def run(): Unit = sc.killExecutor(executorId) + }) } executorLastSeen.remove(executorId) } } } - override def postStop(): Unit = { + override def onStop(): Unit = { if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() + timeoutCheckingTask.cancel(true) } - super.postStop() + timeoutCheckingThread.shutdownNow() + killExecutorThread.shutdownNow() } } + +object HeartbeatReceiver { + val ENDPOINT_NAME = "HeartbeatReceiver" +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3b73a8a8fd850..942c5975ece6d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -32,8 +32,6 @@ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} -import akka.actor.Props - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -48,12 +46,13 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -360,14 +359,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(this)), "HeartbeatReceiver") + private val heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) // Create and start the scheduler private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - heartbeatReceiver ! 
TaskSchedulerIsSet + heartbeatReceiver.send(TaskSchedulerIsSet) @volatile private[spark] var dagScheduler: DAGScheduler = _ try { @@ -455,10 +454,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get + val endpointRef = env.rpcEnv.setupEndpointRef( + SparkEnv.executorActorSystemName, + RpcAddress(host, port), + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -1418,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli dagScheduler = null listenerBus.stop() eventLogger.foreach(_.stop()) - env.actorSystem.stop(heartbeatReceiver) + env.rpcEnv.stop(heartbeatReceiver) progressBar.foreach(_.stop()) taskScheduler = null // TODO: Cache.stop()? diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4a2ed82a40dec..55be0a59fedd9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -295,7 +295,9 @@ object SparkEnv extends Logging { } } - def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) rpcEnv.setupEndpoint(name, endpointCreator) @@ -334,12 +336,13 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + conf, isDriver) // NB: blockManager is not valid until initialize() is called later. 
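
The hunks above replace the Akka-based HeartbeatReceiver with a ThreadSafeRpcEndpoint: one-way messages (TaskSchedulerIsSet and the self-sent ExpireDeadHosts) go through receive, ask-style messages (Heartbeat) go through receiveAndReply, and the periodic timeout check is now an ordinary ScheduledExecutorService that posts to self from onStart. The sketch below shows that shape in isolation against the RpcEndpoint API this patch introduces. It is not part of the patch: the package, endpoint and message names (examplerpc, EchoEndpoint, Tick, Echo) are invented, and it assumes it lives under org.apache.spark because the rpc classes are private[spark].

    package org.apache.spark.examplerpc

    import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

    import scala.util.control.NonFatal

    import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

    case object Tick
    case class Echo(payload: String)

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

      private val scheduler = Executors.newSingleThreadScheduledExecutor()
      private var tickTask: ScheduledFuture[_] = null

      override def onStart(): Unit = {
        // As in HeartbeatReceiver, periodic work is delivered as a message to `self`
        // so it runs under the endpoint's one-message-at-a-time guarantee.
        tickTask = scheduler.scheduleAtFixedRate(new Runnable {
          override def run(): Unit = try {
            Option(self).foreach(_.send(Tick))
          } catch {
            case NonFatal(_) => // swallow, standing in for Utils.tryLogNonFatalError
          }
        }, 0, 10, TimeUnit.SECONDS)
      }

      // Fire-and-forget messages.
      override def receive: PartialFunction[Any, Unit] = {
        case Tick => // periodic bookkeeping would go here
      }

      // Ask-style messages; context.reply replaces Akka's `sender ! ...`.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case Echo(payload) => context.reply(payload.reverse)
      }

      override def onStop(): Unit = {
        if (tickTask != null) {
          tickTask.cancel(true)
        }
        scheduler.shutdownNow()
      }
    }

Registration and use would mirror what SparkContext does with HeartbeatReceiver: `val ref = rpcEnv.setupEndpoint("Echo", new EchoEndpoint(rpcEnv))`, then `ref.send(Tick)` for one-way traffic or `ref.askWithReply[String](Echo("ping"))` when a reply is expected.
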
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 900e678ee02ef..8300f9f2190b9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -21,39 +21,45 @@ import java.net.URL import java.nio.ByteBuffer import scala.collection.mutable -import scala.concurrent.Await +import scala.util.{Failure, Success} -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.rpc._ +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( + override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, hostPort: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + @volatile var driver: Option[RpcEndpointRef] = None - override def preStart() { + override def onStart() { + import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + driver = Some(ref) + ref.sendWithReply[RegisteredExecutor.type]( + RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + } onComplete { + case Success(msg) => Utils.tryLogNonFatalError { + Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + } + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + } } def extractLogUrls: Map[String, String] = { @@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - if (x.remoteAddress == driver.anchorPath.address) { - logError(s"Driver $x disassociated! 
Shutting down.") - System.exit(1) - } else { - logWarning(s"Received irrelevant DisassociatedEvent $x") - } - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.shutdown() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (driver.exists(_.address == remoteAddress)) { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"An unknown ($remoteAddress) driver disconnected.") + } } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + val msg = StatusUpdate(executorId, taskId, state, data) + driver match { + case Some(driverRef) => driverRef.send(msg) + case None => logWarning(s"Drop $msg because has not yet connected to driver") + } } } @@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val driver = fetcher.setupEndpointRefByURI(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() @@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend endpoint. 
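
The run() method above also shows a second idiom introduced by this patch: a short-lived RpcEnv ("driverPropsFetcher") created purely as a client, used to resolve the driver endpoint by URI, ask it for the Spark properties, and then shut down. A hedged sketch of that bootstrap pattern follows; it is illustrative only, the object and parameter names are made up, and it assumes a sub-package of org.apache.spark since RpcEnv is private[spark]. RetrieveSparkProps is the existing message from CoarseGrainedClusterMessages.

    package org.apache.spark.examplerpc

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.RpcEnv
    import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveSparkProps

    object DriverPropsFetchSketch {
      // `driverUrl` is the URI of the driver's CoarseGrainedScheduler endpoint, as
      // passed to the executor on its command line; `hostname` is this host, and
      // port 0 asks for an ephemeral port.
      def fetch(driverUrl: String, hostname: String): Seq[(String, String)] = {
        val conf = new SparkConf
        val fetcher = RpcEnv.create(
          "driverPropsFetcher", hostname, 0, conf, new SecurityManager(conf))
        try {
          val driver = fetcher.setupEndpointRefByURI(driverUrl)
          driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps)
        } finally {
          fetcher.shutdown()
        }
      }
    }

In run() the fetched properties then seed the executor-side configuration before the long-lived SparkEnv and its Executor endpoint are created, as the following hunk shows.
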
val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, userClassPath, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index bf3135ef081c1..14f99a464b6e9 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -27,8 +27,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} @@ -88,9 +86,9 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create an RpcEndpoint for receiving RPCs from the driver + private val executorEndpoint = env.rpcEnv.setupEndpoint( + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst: Boolean = { @@ -139,7 +137,7 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.actorSystem.stop(executorActor) + env.rpcEnv.stop(executorEndpoint) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -391,11 +389,8 @@ private[spark] class Executor( } } - private val timeout = AkkaUtils.lookupTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) private val heartbeatReceiverRef = - AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) /** Reports heartbeat and metrics for active tasks to the driver. 
*/ private def reportHeartBeat(): Unit = { @@ -426,8 +421,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala similarity index 67% rename from core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala rename to core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala index 3e47d13f7545d..cf362f8464735 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala @@ -17,10 +17,8 @@ package org.apache.spark.executor -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils /** * Driver -> Executor message to trigger a thread dump. @@ -28,14 +26,18 @@ import org.apache.spark.util.{Utils, ActorLogReceive} private[spark] case object TriggerThreadDump /** - * Actor that runs inside of executors to enable driver -> executor RPC. + * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { +class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case TriggerThreadDump => - sender ! Utils.getThreadDump() + context.reply(Utils.getThreadDump()) } } + +object ExecutorEndpoint { + val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 7985941d949c0..d47e41abcfa50 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -40,10 +40,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement - * [[RpcEndpoint.self]]. - * - * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this - * on a non-existent endpoint. + * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. */ private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef @@ -58,20 +55,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - /** - * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should - * make sure thread-safely sending messages to [[RpcEndpoint]]. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[RpcEndpoint]]. 
In the other words, changes to internal fields of a [[RpcEndpoint]] - * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be - * volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]] - * for different messages. - */ - def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - /** * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. */ @@ -181,7 +164,7 @@ private[spark] trait RpcEnvFactory { * constructor onStart receive* onStop * * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use - * [[RpcEnv.setupThreadSafeEndpoint]] + * [[ThreadSafeRpcEndpoint]] * * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. @@ -195,7 +178,7 @@ private[spark] trait RpcEndpoint { /** * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is - * called. + * called. And `self` will become `null` when `onStop` is called. * * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. @@ -278,6 +261,19 @@ private[spark] trait RpcEndpoint { } } +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +trait ThreadSafeRpcEndpoint extends RpcEndpoint + /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ @@ -407,7 +403,8 @@ private[spark] object RpcAddress { } /** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. */ private[spark] trait RpcCallContext { diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 769d59b7b3343..9e06147dff1ed 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -82,17 +82,9 @@ private[spark] class AkkaRpcEnv private[akka] ( /** * Retrieve the [[RpcEndpointRef]] of `endpoint`. */ - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { - val endpointRef = endpointToRef.get(endpoint) - require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") - endpointRef - } + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - setupThreadSafeEndpoint(name, endpoint) - } - - override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null // Use lazy because the Actor needs to use `endpointRef`. 
// So `actorRef` should be created after assigning `endpointRef`. diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7227fa9da4317..917cce1f9686c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,14 +23,10 @@ import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -165,11 +161,8 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 9bf74f4be198d..70364cea62a80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -41,6 +42,7 @@ private[spark] object CoarseGrainedClusterMessages { // Executors to driver case class RegisterExecutor( executorId: String, + executorRef: RpcEndpointRef, hostPort: String, cores: Int, logUrls: Map[String, String]) @@ -70,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Exchanged between the driver and the AM in Yarn client mode case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage @@ -77,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages { // Messages exchanged between the driver and the cluster manager for executor allocation // In Yarn mode, these are exchanged between the driver and the AM - case object RegisterClusterManager extends CoarseGrainedClusterMessage + case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage // Request executors by specifying the new total number of executors desired // This includes executors already pending or running diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5d258d9da4d1a..4c49da87af9dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,20 +17,16 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. 
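
The next hunk rewrites DriverActor into DriverEndpoint. From the caller's side of the same API, the Akka idioms map almost mechanically: `actor ! msg` becomes `ref.send(msg)`, and `ask` plus `Await.result` becomes a typed `askWithReply[T]`, optionally with an explicit timeout as the DAGScheduler change above uses for BlockManagerHeartbeat. A small illustrative sketch, with message and object names invented and not part of the patch:

    package org.apache.spark.examplerpc

    import scala.concurrent.duration._

    import org.apache.spark.rpc.RpcEndpointRef

    case object Nudge
    case class Evict(executorId: String)

    object CallerSketch {
      def drive(driverRef: RpcEndpointRef): Boolean = {
        // Fire-and-forget, the replacement for `driverActor ! ReviveOffers`.
        driverRef.send(Nudge)
        // Blocking ask with a typed reply; the overload taking an explicit timeout
        // replaces the old `Patterns.ask` + `Await.result` pair.
        driverRef.askWithReply[Boolean](Evict("exec-1"), 600.seconds)
      }
    }

The DriverEndpoint in the hunk below then answers such asks through receiveAndReply and context.reply instead of `sender ! ...`.
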
@@ -71,48 +66,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends ThreadSafeRpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val addressToExecutorId = new HashMap[RpcAddress, String] + + private val reviveThread = + Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread")) + override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } - - def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor - - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) - // This must be synchronized because variables mutated - // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { - executorDataMap.put(executorId, data) - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } - } - listenerBus.post( - SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) - makeOffers() + reviveThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) } + }, 0, reviveInterval, TimeUnit.MILLISECONDS) + } + override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -133,33 +106,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. 
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorDataMap.contains(executorId)) { + context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + } else { + logInfo("Registered executor: " + executorRef + " with ID " + executorId) + context.reply(RegisteredExecutor) + + addressToExecutorId(executorRef.address) = executorId + totalCoreCount.addAndGet(cores) + totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + makeOffers() + } case StopDriver => - sender ! true - context.stop(self) + context.reply(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorEndpoint.send(StopExecutor) } - sender ! true + context.reply(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + context.reply(true) case RetrieveSparkProps => - sender ! sparkProperties + context.reply(sparkProperties) } // Make fake resource offers on all executors @@ -169,6 +167,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste }.toSeq)) } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Rpc client disassociated")) + } + // Make fake resource offers on just one executor def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) @@ -199,7 +202,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! 
LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } @@ -223,9 +226,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case None => logError(s"Asked to remove non-existent executor $executorId") } } + + override def onStop() { + reviveThread.shutdownNow() + } } - var driverActor: ActorRef = null + var driverEndpoint: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -236,16 +243,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverEndpoint = rpcEnv.setupEndpoint( + CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) } def stopExecutors() { try { - if (driverActor != null) { + if (driverEndpoint != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](StopExecutors) } } catch { case e: Exception => @@ -256,22 +262,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste override def stop() { stopExecutors() try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + if (driverEndpoint != null) { + driverEndpoint.askWithReply[Boolean](StopDriver) } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) + throw new SparkException("Error stopping standalone scheduler's driver endpoint", e) } } override def reviveOffers() { - driverActor ! ReviveOffers + driverEndpoint.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! 
KillTask(taskId, executorId, interruptThread) + driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -281,11 +286,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) + throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) } } @@ -391,5 +395,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } private[spark] object CoarseGrainedSchedulerBackend { - val ACTOR_NAME = "CoarseGrainedScheduler" + val ENDPOINT_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 5e571efe76720..26e72c0bff38d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,20 +17,20 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, override val totalCores: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 06786a59524e7..0324c9dab910b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.AkkaUtils private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") @@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - 
SparkEnv.driverActorSystemName, - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ffd4825705755..7eb3fdc19b5b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,17 +19,18 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { @@ -48,12 +49,9 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 5a38ad9f2b12c..f72566c370a6f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils @@ -37,7 +35,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if 
(conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,10 +43,8 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) @@ -57,16 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend( } /** - * An actor that communicates with the ApplicationMaster. + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None - - implicit val askAmActorExecutor = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private val askAmThreadPool = + Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) override def receive: PartialFunction[Any, Unit] = { - case RegisterClusterManager => - logInfo(s"ApplicationMaster registered as $sender") - amActor = Some(sender) + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Some(am) + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + context.reply(am.askWithReply[Boolean](r)) } onFailure { - case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + context.reply(false) } case k: KillExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! 
AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + context.reply(am.askWithReply[Boolean](k)) } onFailure { - case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + context.reply(false) } - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } + } + + override def onStop(): Unit ={ + askAmThreadPool.shutdownNow() } } } private[spark] object YarnSchedulerBackend { - val ACTOR_NAME = "YarnScheduler" + val ENDPOINT_NAME = "YarnScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e13de0f46ef89..b037a4966ced0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { @@ -148,7 +148,7 @@ private[spark] class CoarseMesosSchedulerBackend( SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.get("spark.executor.uri", null) if (uri == null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index eb3f999b5b375..70a477a6895cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,17 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import java.util.concurrent.{Executors, TimeUnit} -import scala.concurrent.duration._ -import scala.language.postfixOps - -import akka.actor.{Actor, ActorRef, Props} - +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -39,17 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalActor. 
Using an actor makes the calls on - * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the + * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend * and the TaskSchedulerImpl. */ -private[spark] class LocalActor( +private[spark] class LocalEndpoint( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { - import context.dispatcher // to use Akka's scheduler.scheduleOnce() + private val reviveThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("local-revive-thread")) private var freeCores = totalCores @@ -59,7 +58,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -87,9 +86,17 @@ private[spark] class LocalActor( } if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers) + reviveThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) + } + }, 1000, TimeUnit.MILLISECONDS) } } + + override def onStop(): Unit = { + reviveThread.shutdownNow() + } } /** @@ -101,31 +108,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: extends SchedulerBackend with ExecutorBackend { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localEndpoint: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( + "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localEndpoint.send(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localEndpoint.send(ReviveOffers) } override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! 
StatusUpdate(taskId, state, serializedData) + localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fc31296f4deb3..1aa0ef18de118 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ @@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -64,7 +64,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -136,9 +136,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveEndpoint = rpcEnv.setupEndpoint( + "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -167,7 +167,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +176,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -186,7 +186,7 @@ private[spark] class BlockManager( * where it is only learned after registration with the TaskScheduler). * * This method initializes the BlockTransferService and ShuffleClient, registers with the - * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { @@ -202,7 +202,7 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. 
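
The BlockManager changes above register a per-instance slave endpoint and hand its RpcEndpointRef to the master inside the registration message, replacing the old ActorRef plumbing; RegisterExecutor gains an executorRef field for the same reason. The sketch below shows that registration pattern on its own. It is not part of the patch: the names (WorkerCallbackEndpoint, RegisterWorker, DoCleanup) are invented, and it assumes a sub-package of org.apache.spark so the private[spark] rpc types are visible.

    package org.apache.spark.examplerpc

    import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

    // Hypothetical registration message carrying a ref to the worker's own endpoint,
    // analogous to RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint).
    case class RegisterWorker(workerId: String, callback: RpcEndpointRef)
    case class DoCleanup(resource: String)

    class WorkerCallbackEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      override def receive: PartialFunction[Any, Unit] = {
        case DoCleanup(resource) => // drop local state for `resource`
      }
    }

    object WorkerRegistrationSketch {
      def register(rpcEnv: RpcEnv, master: RpcEndpointRef, workerId: String): Boolean = {
        // A unique endpoint name per instance, as BlockManager does with
        // "BlockManagerEndpoint" + ID_GENERATOR.next.
        val callback = rpcEnv.setupEndpoint(
          s"WorkerCallback-$workerId", new WorkerCallbackEndpoint(rpcEnv))
        // The master keeps `callback` and can later `send` it DoCleanup messages.
        master.askWithReply[Boolean](RegisterWorker(workerId, callback))
      }
    }

Because RpcEndpointRef is serializable and addressable, passing it in a message gives the master a direct handle for later callbacks, which is exactly what BlockManagerMaster does with slaveEndpoint in the hunks that follow.
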
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -265,7 +265,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -1215,7 +1215,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 061964826f08b..ceacf043029f3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,35 +20,31 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverEndpoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" val timeout = AkkaUtils.askTimeout(conf) - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { tell(RemoveExecutor(execId)) logInfo("Removed " + execId + " successfully in removeExecutor") } /** Register the BlockManager's id with the driver. 
*/ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -59,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( + val res = driverEndpoint.askWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -67,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -85,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -97,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - askDriverWithReply(RemoveBlock(blockId)) + driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -114,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -126,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. 
*/ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -145,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) } /** @@ -165,11 +161,12 @@ class BlockManagerMaster( askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) /* - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. + * master endpoint for a response to a prior message. */ - val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val response = driverEndpoint. + askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -193,33 +190,28 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } - /** Stop the driver actor, called only on the Spark driver node */ + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { - if (driverActor != null && isDriver) { + if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) - driverActor = null + driverEndpoint = null logInfo("BlockManagerMaster stopped") } } - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") + if (!driverEndpoint.askWithReply[Boolean](message)) { + throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. 
- */ - private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) - } +} +private[spark] object BlockManagerMaster { + val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 5b5328016124e..28c73a7d543ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,25 +21,26 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} -import akka.actor.{Actor, ActorRef} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.Utils /** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -50,68 +51,67 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout(conf) + private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) case UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! 
updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) case GetLocations(blockId) => - sender ! getLocations(blockId) + context.reply(getLocations(blockId)) case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) + context.reply(getLocationsMultipleBlockIds(blockIds)) case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) + context.reply(getPeers(blockManagerId)) - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) case GetMemoryStatus => - sender ! memoryStatus + context.reply(memoryStatus) case GetStorageStatus => - sender ! storageStatus + context.reply(storageStatus) case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) + context.reply(blockStatus(blockId, askSlaves)) case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) + context.reply(getMatchingBlockIds(filter, askSlaves)) case RemoveRdd(rddId) => - sender ! removeRdd(rddId) + context.reply(removeRdd(rddId)) case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) + context.reply(removeShuffle(shuffleId)) case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) + context.reply(removeBroadcast(broadcastId, removeFromDriver)) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) - sender ! true + context.reply(true) case RemoveExecutor(execId) => removeExecutor(execId) - sender ! true + context.reply(true) case StopBlockManagerMaster => - sender ! true - context.stop(self) + context.reply(true) + stop() case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) + context.reply(heartbeatReceived(blockManagerId)) - case other => - logWarning("Got unknown message: " + other) } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -129,22 +129,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher + // Nothing to do in the BlockManagerMasterEndpoint data structures val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) }.toSeq ) } @@ -155,14 +153,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. 
*/ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } @@ -217,7 +214,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) + blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) } } } @@ -247,17 +244,16 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) /* - * Rather than blocking on the block status query, master actor should simply return + * Rather than blocking on the block status query, master endpoint should simply return * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. + * that is also waiting for this master endpoint's response to a previous message. */ blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -276,13 +272,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } @@ -291,7 +286,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -308,7 +303,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -379,19 +374,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * 
BlockManagerSlaveActor. + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port + info <- blockManagerInfo.get(blockManagerId) ) yield { - (host, port) + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) } } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } } @DeveloperApi @@ -412,7 +409,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveEndpoint: RpcEndpointRef) extends Logging { private var _lastSeenMs: Long = timeMs diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 48247453edef0..f89d8d7493f7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 52fb896c4e21f..8980fa8eb70e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -17,41 +17,43 @@ package org.apache.spark.storage -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} +import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive /** - * An actor to take commands from the master to execute options. For example, + * An RpcEndpoint to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. 
*/ private[storage] -class BlockManagerSlaveActor( +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { + extends RpcEndpoint with Logging { - import context.dispatcher + private val asyncThreadPool = + Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { + doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) true } case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { + doAsync[Int]("removing RDD " + rddId, context) { blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { + doAsync[Boolean]("removing shuffle " + shuffleId, context) { if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } @@ -59,30 +61,34 @@ class BlockManagerSlaveActor( } case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { + doAsync[Int]("removing broadcast " + broadcastId, context) { blockManager.removeBroadcast(broadcastId, tellMaster = true) } case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) + context.reply(blockManager.getStatus(blockId)) case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) + context.reply(blockManager.getMatchingBlockIds(filter)) } - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { val future = Future { logDebug(actionMessage) body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] + context.sendFailure(t) } } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7c85e28679f1d..0fdfaf300e95d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1214,6 +1214,16 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ + def tryLogNonFatalError(block: => Unit) { + try { + block + } catch { + case NonFatal(t) => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } + /** * Execute a block of code, then a finally block, but if exceptions happen in * the finally block, do not suppress the original exception. 
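The BlockManagerSlaveEndpoint introduced above replies to the master asynchronously: each potentially slow operation runs on a dedicated cached thread pool, and the Future's outcome is routed back through the caller's context (reply on success, sendFailure on error), so the receiveAndReply handler itself never blocks. Below is a minimal standalone sketch of that pattern. MiniCallContext and AsyncReplyPattern are hypothetical names written only for illustration; the sketch mirrors the shape of the doAsync helper in the patch but does not use Spark's actual RpcCallContext or RpcEndpoint APIs.

```scala
// Standalone sketch of the asynchronous reply pattern: run slow work on a
// separate thread pool and complete the caller's context when the Future
// finishes. MiniCallContext is a stand-in for an RPC reply context.
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

trait MiniCallContext {
  def reply(response: Any): Unit
  def sendFailure(e: Throwable): Unit
}

object AsyncReplyPattern {
  // Dedicated pool so the message-handling thread is never blocked.
  private val pool = Executors.newCachedThreadPool()
  private implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutorService(pool)

  /** Run `body` asynchronously and route its result (or error) back to the caller. */
  def doAsync[T](actionMessage: String, context: MiniCallContext)(body: => T): Unit = {
    Future {
      println(s"starting: $actionMessage")
      body
    }.onComplete {
      case Success(response) => context.reply(response)   // like context.reply(...)
      case Failure(t)        => context.sendFailure(t)    // like context.sendFailure(...)
    }
  }

  def main(args: Array[String]): Unit = {
    val ctx = new MiniCallContext {
      def reply(response: Any): Unit = println(s"reply: $response")
      def sendFailure(e: Throwable): Unit = println(s"failure: ${e.getMessage}")
    }
    doAsync[Boolean]("removing block rdd_0_0", ctx) { true }      // replies with true
    doAsync[Int]("removing RDD 7", ctx) { sys.error("boom"); 0 }  // error is sent back, not thrown
    Thread.sleep(500)
    pool.shutdown()
  }
}
```

Running main prints the reply for the first call and the routed failure for the second, which is the behavior the endpoint relies on to stay responsive while block-removal work is in flight.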
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 0000000000000..0fd570e5297d9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito.{mock, spy, verify, when} +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.RpcUtils +import org.scalatest.concurrent.Eventually._ + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + 
verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index e07bdb9637575..4f19c4f2110d2 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -311,7 +311,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("self: call in onStop") { - @volatile var e: Throwable = null + @volatile var selfOption: Option[RpcEndpointRef] = null val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { override val rpcEnv = env @@ -321,20 +321,18 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } override def onStop(): Unit = { - self + selfOption = Option(self) } override def onError(cause: Throwable): Unit = { - e = cause } }) env.stop(endpointRef) eventually(timeout(5 seconds), interval(10 millis)) { - // Calling `self` in `onStop` is invalid - assert(e != null) - assert(e.getMessage.contains("Cannot find RpcEndpointRef")) + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption == None) } } @@ -342,7 +340,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it for(i <- 0 until 100) { @volatile var result = 0 - val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive = { @@ -475,7 +473,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("network events") { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupThreadSafeEndpoint("network-events", new RpcEndpoint { + env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive = { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c8597997..b4de90b65d545 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,11 +22,11 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService @@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in 
BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,12 +68,10 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -83,18 +80,17 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } after { allStores.foreach { _.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd1cba5b5abe..283090e3bdb1f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,24 +19,18 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ 
-import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} - import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService @@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -72,28 +66,25 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -108,9 +99,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -357,10 +348,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] + val reregister = !master.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -785,7 +774,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. 
val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 18a477f92094d..ef4873de2f5a9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,20 +24,20 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.{AkkaUtils, ManualClock, Utils} +import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -54,22 +54,19 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null Utils.deleteRecursively(tempDirectory) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 455554eea0597..24a1e02795218 100644 --- 
a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -24,22 +24,20 @@ import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, - SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. @@ -72,8 +70,8 @@ private[spark] class ApplicationMaster( @volatile private var allocator: YarnAllocator = _ // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) @@ -240,22 +238,21 @@ private[spark] class ApplicationMaster( } /** - * Create an actor that communicates with the driver. + * Create an [[RpcEndpoint]] that communicates with the driver. * * In cluster mode, the AM and the driver belong to same process - * so the AM actor need not monitor lifecycle of the driver. + * so the AMEndpoint need not monitor lifecycle of the driver. 
*/ - private def runAMActor( + private def runAMEndpoint( host: String, port: String, isClusterMode: Boolean): Unit = { - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), + val driverEndpont = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, - host, - port, - YarnSchedulerBackend.ACTOR_NAME) - actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isClusterMode)), name = "YarnAM") + RpcAddress(host, port.toInt), + YarnSchedulerBackend.ENDPOINT_NAME) + amEndpoint = + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -272,8 +269,8 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { - actorSystem = sc.env.actorSystem - runAMActor( + rpcEnv = sc.env.rpcEnv + runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) @@ -283,8 +280,7 @@ private[spark] class ApplicationMaster( } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityMgr)._1 + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -431,7 +427,7 @@ private[spark] class ApplicationMaster( sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - runAMActor(driverHost, driverPort.toString, isClusterMode = false) + runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -443,7 +439,7 @@ private[spark] class ApplicationMaster( System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) + amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase)) } } @@ -505,44 +501,29 @@ private[spark] class ApplicationMaster( } /** - * An actor that communicates with the driver's scheduler backend. + * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor { - var driver: ActorSelection = _ - - override def preStart(): Unit = { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - driver ! RegisterClusterManager - // In cluster mode, the AM can directly monitor the driver status instead - // of trying to deduce it from the lifecycle of the driver's actor - if (!isClusterMode) { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private class AMEndpoint( + override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) + extends RpcEndpoint with Logging { + + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) } override def receive: PartialFunction[Any, Unit] = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. 
$x") - // In cluster mode, do not rely on the disassociated event to exit - // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isClusterMode) { - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - } - case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -550,7 +531,16 @@ private[spark] class ApplicationMaster( case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index c98763e15b58f..b8f42dadcb464 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,7 +112,7 @@ private[yarn] class YarnAllocator( SparkEnv.driverActorSystemName, sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) From acffc43455d7b3e4000be4ff0175b8ea19cd280b Mon Sep 17 00:00:00 2001 From: lewuathe Date: Sun, 5 Apr 2015 16:13:31 -0700 Subject: [PATCH 151/171] [SPARK-6262][MLLIB]Implement missing methods for MultivariateStatisticalSummary Add below methods in pyspark for MultivariateStatisticalSummary - normL1 - normL2 Author: lewuathe Closes #5359 from Lewuathe/SPARK-6262 and squashes the following commits: cbe439e [lewuathe] Implement missing methods for MultivariateStatisticalSummary --- python/pyspark/mllib/stat/_statistics.py | 6 ++++++ python/pyspark/mllib/tests.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 218ac148ca992..1d83e9d483f8e 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -49,6 +49,12 @@ def max(self): def min(self): return self.call("min").toArray() + def normL1(self): + return self.call("normL1").toArray() + + def normL2(self): + return self.call("normL2").toArray() + class Statistics(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index dd3b66ce67457..47dad7d12e4e4 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -357,6 +357,12 @@ def test_col_with_different_rdds(self): summary = Statistics.colStats(data) self.assertEqual(10, summary.count()) + def 
test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + class VectorUDTTests(PySparkTestCase): From 0b5d028a93b7d5adb148fbf3a576257bb3a6d8cb Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 5 Apr 2015 21:57:15 -0700 Subject: [PATCH 152/171] [SPARK-6602][Core] Update MapOutputTrackerMasterActor to MapOutputTrackerMasterEndpoint This is the second PR for [SPARK-6602]. It updated MapOutputTrackerMasterActor and its unit tests. cc rxin Author: zsxwing Closes #5371 from zsxwing/rpc-rewrite-part2 and squashes the following commits: fcf3816 [zsxwing] Fix the code style 4013a22 [zsxwing] Add doc for uncaught exceptions in RpcEnv 93c6c20 [zsxwing] Add an example of UnserializableException and add ErrorMonitor to monitor errors from Akka 134fe7b [zsxwing] Update MapOutputTrackerMasterActor to MapOutputTrackerMasterEndpoint --- .../org/apache/spark/MapOutputTracker.scala | 61 +++--- .../scala/org/apache/spark/SparkEnv.scala | 18 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 4 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 19 +- .../apache/spark/MapOutputTrackerSuite.scala | 100 +++++---- .../org/apache/spark/rpc/RpcEnvSuite.scala | 33 ++- .../apache/spark/util/AkkaUtilsSuite.scala | 198 ++++++++---------- 7 files changed, 221 insertions(+), 212 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5718951451afc..d65c94e410662 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,13 +21,11 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await +import scala.collection.mutable.{HashSet, Map} import scala.collection.JavaConversions._ +import scala.reflect.ClassTag -import akka.actor._ -import akka.pattern.ask - +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.BlockManagerId @@ -38,14 +36,15 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -/** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +/** RpcEndpoint class for MapOutputTrackerMaster */ +private[spark] class MapOutputTrackerMasterEndpoint( + override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) + extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort + val hostPort = context.sender.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size @@ 
-53,19 +52,19 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." - /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception. - * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239) - * will ultimately remove this entire code path. */ + /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. + * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ val exception = new SparkException(msg) logError(msg, exception) - throw exception + context.sendFailure(exception) + } else { + context.reply(mapOutputStatuses) } - sender ! mapOutputStatuses case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + logInfo("MapOutputTrackerMasterEndpoint stopped!") + context.reply(true) + stop() } } @@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ + var trackerEndpoint: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -105,12 +101,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private val fetching = new HashSet[Int] /** - * Send a message to the trackerActor and get its result within a default timeout, or + * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. */ - protected def askTracker(message: Any): Any = { + protected def askTracker[T: ClassTag](message: Any): T = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerEndpoint.askWithReply[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -118,9 +114,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. 
Expecting true, got " + response.toString) @@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def stop() { sendTracker(StopMapOutputTracker) mapStatuses.clear() - trackerActor = null + trackerEndpoint = null metadataCleaner.cancel() cachedSerializedStatuses.clear() } @@ -350,6 +345,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { + val ENDPOINT_NAME = "MapOutputTracker" + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 55be0a59fedd9..0171488e09562 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor._ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -41,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: @@ -286,15 +285,6 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { - if (isDriver) { - logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) - } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) - } - } - def registerOrLookupEndpoint( name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { @@ -314,9 +304,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( diff --git 
a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index d47e41abcfa50..e259867c14040 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -30,7 +30,9 @@ import org.apache.spark.util.{AkkaUtils, Utils} /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote - * nodes, and deliver them to corresponding [[RpcEndpoint]]s. + * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by + * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the + * sender, or logging them if no such sender or `NotSerializableException`. * * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. */ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 9e06147dff1ed..652e52f2b2e73 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,16 +17,16 @@ package org.apache.spark.rpc.akka -import java.net.URI import java.util.concurrent.ConcurrentHashMap -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import org.apache.spark.{SparkException, Logging, SparkConf} @@ -242,10 +242,25 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { def create(config: RpcEnvConfig): RpcEnv = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) + actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") new AkkaRpcEnv(actorSystem, config.conf, boundPort) } } +/** + * Monitor errors reported by Akka and log them. 
+ */ +private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { + + override def preStart(): Unit = { + context.system.eventStream.subscribe(self, classOf[Error]) + } + + override def receiveWithLogging: Actor.Receive = { + case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + } +} + private[akka] class AkkaRpcEndpointRef( @transient defaultAddress: RpcAddress, @transient _actorRef: => ActorRef, diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ccfe0678cb1c3..6295d34be5ca9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,34 +17,37 @@ package org.apache.spark -import scala.concurrent.Await - -import akka.actor._ -import akka.testkit.TestActorRef +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, isA} import org.scalatest.FunSuite +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf + def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, + securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { + RpcEnv.create(name, host, port, conf, securityManager) + } + test("master start and stop") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -57,13 +60,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register and unregister shuffle") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -78,14 +82,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(tracker.getServerStatuses(10, 0).isEmpty) tracker.stop() - 
actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and unregister map output and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -104,25 +108,21 @@ class MapOutputTrackerSuite extends FunSuite { intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -147,8 +147,8 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.stop() slaveTracker.stop() - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch below akka frame size") { @@ -157,19 +157,24 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("spark") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - masterActor.receive(GetMapOutputStatuses(10)) + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + 
masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + verify(rpcCallContext).reply(any()) + verify(rpcCallContext, never()).sendFailure(any()) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch exceeds akka frame size") { @@ -178,12 +183,11 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("test") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) @@ -191,9 +195,15 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + verify(rpcCallContext, never()).reply(any()) + verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 4f19c4f2110d2..5a734ec5ba5ec 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -514,10 +514,35 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { ("onDisconnected", remoteAddress))) } } -} -case object Start + test("sendWithReply: unserializable error") { + env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { + override val rpcEnv = env -case class Ping(id: Int) + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => context.sendFailure(new UnserializableException) + } + }) -case class Pong(id: Int) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-unserializable-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + intercept[TimeoutException] { + Await.result(f, 1 seconds) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + +} + +class UnserializableClass + +class UnserializableException extends Exception { + private val unserializableField = new UnserializableClass +} diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala 
index 6250d50fb7036..bec79fc4dc8f7 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -19,14 +19,11 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException -import scala.concurrent.Await -import scala.util.{Failure, Try} - -import akka.actor._ - +import akka.actor.ActorNotFound import org.scalatest.FunSuite import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId import org.apache.spark.SSLSampleConfigs._ @@ -39,39 +36,37 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro test("remote fetch security bad password") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "true") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off") { @@ -81,28 +76,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new 
MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -120,8 +111,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security pass") { @@ -131,15 +122,14 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") @@ -148,13 +138,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = goodconf, securityManager = securityManagerGood) + val slaveRpcEnv =RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -170,47 +157,45 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === 
Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off client") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on") { @@ -218,26 +203,22 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new 
MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -255,8 +236,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -267,28 +248,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === true) @@ -305,45 +282,43 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on and security enabled - bad credentials") { val conf = sparkSSLConfig() + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + 
System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + slaveConf.set("spark.rpc", "akka") slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -352,35 +327,30 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - val result = Try(Await.result(selection.resolveOne(timeout * 2), timeout)) - - result match { - case Failure(ex: ActorNotFound) => - case Failure(ex: TimeoutException) => - case r => fail(s"$r is neither Failure(ActorNotFound) nor Failure(TimeoutException)") + try { + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + fail("should receive either ActorNotFound or TimeoutException") + } catch { + case e: ActorNotFound => + case e: TimeoutException => } - actorSystem.shutdown() - slaveSystem.shutdown() + 
rpcEnv.shutdown() + slaveRpcEnv.shutdown() } } From 49f38824a4770fc9017e6cc9b1803c4543b0c081 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Mon, 6 Apr 2015 10:11:20 +0100 Subject: [PATCH 153/171] [SPARK-6673] spark-shell.cmd can't start in Windows even when spark was built added equivalent script to load-spark-env.sh Author: Masayoshi TSUZUKI Closes #5328 from tsudukim/feature/SPARK-6673 and squashes the following commits: aaefb19 [Masayoshi TSUZUKI] removed dust. be3405e [Masayoshi TSUZUKI] [SPARK-6673] spark-shell.cmd can't start in Windows even when spark was built --- bin/load-spark-env.cmd | 59 ++++++++++++++++++++++++++++++++++++++++++ bin/pyspark2.cmd | 3 +-- bin/run-example2.cmd | 3 +-- bin/spark-class2.cmd | 3 +-- 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 bin/load-spark-env.cmd diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd new file mode 100644 index 0000000000000..36d932c453b6f --- /dev/null +++ b/bin/load-spark-env.cmd @@ -0,0 +1,59 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. +rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's +rem conf/ subdirectory. + +if [%SPARK_ENV_LOADED%] == [] ( + set SPARK_ENV_LOADED=1 + + if not [%SPARK_CONF_DIR%] == [] ( + set user_conf_dir=%SPARK_CONF_DIR% + ) else ( + set user_conf_dir=%~dp0..\..\conf + ) + + call :LoadSparkEnv +) + +rem Setting SPARK_SCALA_VERSION if not already set. + +set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 +set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 + +if [%SPARK_SCALA_VERSION%] == [] ( + + if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( + echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." + echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + exit 1 + ) + if exist %ASSEMBLY_DIR2% ( + set SPARK_SCALA_VERSION=2.11 + ) else ( + set SPARK_SCALA_VERSION=2.10 + ) +) +exit /b 0 + +:LoadSparkEnv +if exist "%user_conf_dir%\spark-env.cmd" ( + call "%user_conf_dir%\spark-env.cmd" +) diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 4f5eb5e20614d..09b4149c2a439 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Figure out which Python to use. 
if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b49d0dcb4ff2d..c3e0221fb62e3 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -25,8 +25,7 @@ set FWDIR=%~dp0..\ rem Export this as SPARK_HOME set SPARK_HOME=%FWDIR% -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if not "x%1"=="x" goto arg_given diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 4ce727bc99128..4b3401d745f2a 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if "x%1"=="x" ( From 9fe41252198df71f4629843d363db8c83f36440c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 6 Apr 2015 10:18:56 +0100 Subject: [PATCH 154/171] SPARK-6569 [STREAMING] Down-grade same-offset message in Kafka streaming to INFO Reduce "is the same as ending offset" message to INFO level per JIRA discussion Author: Sean Owen Closes #5366 from srowen/SPARK-6569 and squashes the following commits: 8a5b992 [Sean Owen] Reduce "is the same as ending offset" message to INFO level per JIRA discussion --- .../main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 4a83b715fa89d..a0b8a0c565210 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -86,7 +86,7 @@ class KafkaRDD[ val part = thePart.asInstanceOf[KafkaRDDPartition] assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { - log.warn(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + log.info(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { From 30363ede8635f2548e444697dbcf60a795b61a84 Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Mon, 6 Apr 2015 13:15:01 -0700 Subject: [PATCH 155/171] [MLlib] [SPARK-6713] Iterators in columnSimilarities for mapPartitionsWithIndex Use Iterators in columnSimilarities to allow mapPartitionsWithIndex to spill to disk. This could happen in a dense and large column - this way Spark can spill the pairs onto disk instead of building all the pairs before handing them to Spark. Another PR coming to update documentation. 
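To make the memory argument concrete, here is a minimal standalone sketch of the two emission strategies; the helper names (emitPairsBuffered, emitPairsLazily) and the simplified pair weights are illustrative assumptions, not code from this patch. The buffered form materializes every pair for a row before returning anything; the iterator form yields pairs on demand, so the consumer (Spark's shuffle feeding reduceByKey) can spill them incrementally.

  import scala.collection.mutable.ListBuffer

  // Buffered: the whole cross product of pairs for one row sits in memory
  // before the first element reaches the caller.
  def emitPairsBuffered(indices: Array[Int], values: Array[Double]): Iterator[((Int, Int), Double)] = {
    val buf = new ListBuffer[((Int, Int), Double)]()
    var k = 0
    while (k < indices.length) {
      var l = k + 1
      while (l < indices.length) {
        buf += ((indices(k), indices(l)) -> values(k) * values(l))
        l += 1
      }
      k += 1
    }
    buf.iterator
  }

  // Lazy: pairs are produced only as they are consumed, so downstream
  // aggregation can spill them to disk instead of holding them all at once.
  def emitPairsLazily(indices: Array[Int], values: Array[Double]): Iterator[((Int, Int), Double)] = {
    Iterator.range(0, indices.length).flatMap { k =>
      Iterator.range(k + 1, indices.length).map { l =>
        (indices(k), indices(l)) -> values(k) * values(l)
      }
    }
  }

Both produce the same pairs; only the lazy form lets them stream straight into the shuffle, which is what the RowMatrix change below achieves with Iterator.tabulate.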
Author: Reza Zadeh Closes #5364 from rezazadeh/optmemsim and squashes the following commits: 47c90ba [Reza Zadeh] Iterators in columnSimilarities for flatMap --- .../mllib/linalg/distributed/RowMatrix.scala | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 961111507f2c2..9a89a6f3a515f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -531,7 +531,6 @@ class RowMatrix( val rand = new XORShiftRandom(indx) val scaled = new Array[Double](p.size) iter.flatMap { row => - val buf = new ListBuffer[((Int, Int), Double)]() row match { case SparseVector(size, indices, values) => val nnz = indices.size @@ -540,8 +539,9 @@ class RowMatrix( scaled(k) = values(k) / q(indices(k)) k += 1 } - k = 0 - while (k < nnz) { + + Iterator.tabulate (nnz) { k => + val buf = new ListBuffer[((Int, Int), Double)]() val i = indices(k) val iVal = scaled(k) if (iVal != 0 && rand.nextDouble() < p(i)) { @@ -555,8 +555,8 @@ class RowMatrix( l += 1 } } - k += 1 - } + buf + }.flatten case DenseVector(values) => val n = values.size var i = 0 @@ -564,8 +564,8 @@ class RowMatrix( scaled(i) = values(i) / q(i) i += 1 } - i = 0 - while (i < n) { + Iterator.tabulate (n) { i => + val buf = new ListBuffer[((Int, Int), Double)]() val iVal = scaled(i) if (iVal != 0 && rand.nextDouble() < p(i)) { var j = i + 1 @@ -577,10 +577,9 @@ class RowMatrix( j += 1 } } - i += 1 - } + buf + }.flatten } - buf } }.reduceByKey(_ + _).map { case ((i, j), sim) => MatrixEntry(i.toLong, j.toLong, sim) From e40ea8742a8771ecd46b182f45b5fcd8bd6dd725 Mon Sep 17 00:00:00 2001 From: Volodymyr Lyubinets Date: Mon, 6 Apr 2015 18:00:51 -0700 Subject: [PATCH 156/171] [Minor] [SQL] [SPARK-6729] Minor fix for DriverQuirks get The function uses .substring(0, X), which will trigger OutOfBoundsException if string length is less than X. A better way to do this is to use startsWith, which won't error out in this case. Author: Volodymyr Lyubinets Closes #5378 from vlyubin/quirks and squashes the following commits: 504e8e0 [Volodymyr Lyubinets] Minor fix for DriverQuirks get --- .../main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala index 1704be7fcbd30..0feabc4282f4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala @@ -49,9 +49,9 @@ private[sql] object DriverQuirks { * Fetch the DriverQuirks class corresponding to a given database url. */ def get(url: String): DriverQuirks = { - if (url.substring(0, 10).equals("jdbc:mysql")) { + if (url.startsWith("jdbc:mysql")) { new MySQLQuirks() - } else if (url.substring(0, 15).equals("jdbc:postgresql")) { + } else if (url.startsWith("jdbc:postgresql")) { new PostgresQuirks() } else { new NoQuirks() From a0846c4b635eac8d8637c83d490177f881952d27 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Apr 2015 23:33:16 -0700 Subject: [PATCH 157/171] [SPARK-6716] Change SparkContext.DRIVER_IDENTIFIER from to driver Currently, the driver's executorId is set to ``. 
This choice of ID was present in older Spark versions, but it has started to cause problems now that executorIds are used in more contexts, such as Ganglia metric names or driver thread-dump links in the web UI. The angle brackets must be escaped when embedding this ID in XML or as part of URLs and this has led to multiple problems: - https://issues.apache.org/jira/browse/SPARK-6484 - https://issues.apache.org/jira/browse/SPARK-4313 The simplest solution seems to be to change this id to something that does not contain any special characters, such as `driver`. I'm not sure whether we can perform this change in a patch release, since this ID may be considered a stable API by metrics users, but it's probably okay to do this in a major release as long as we document it in the release notes. Author: Josh Rosen Closes #5372 from JoshRosen/driver-id-fix and squashes the following commits: 42d3c10 [Josh Rosen] Clarify comment 0c5d04b [Josh Rosen] Add backwards-compatibility in BlockManagerId.isDriver 7ff12e0 [Josh Rosen] Change SparkContext.DRIVER_IDENTIFIER from <driver> to driver --- .../main/scala/org/apache/spark/SparkContext.scala | 12 +++++++++++- .../org/apache/spark/storage/BlockManagerId.scala | 5 ++++- .../org/apache/spark/storage/BlockManagerSuite.scala | 6 ++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 942c5975ece6d..3f1a7dd99d635 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1901,7 +1901,17 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" - private[spark] val DRIVER_IDENTIFIER = "<driver>" + /** + * Executor id for the driver. In earlier versions of Spark, this was `<driver>`, but this was + * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see + * SPARK-6716 for more details). + */ + private[spark] val DRIVER_IDENTIFIER = "driver" + + /** + * Legacy version of DRIVER_IDENTIFIER, retained for backwards-compatibility. + */ + private[spark] val LEGACY_DRIVER_IDENTIFIER = "<driver>" // The following deprecated objects have already been copied to `object AccumulatorParam` to // make the compiler find them automatically.
They are duplicate codes only for backward diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index a6f1ebf325a7c..69ac37511e730 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -60,7 +60,10 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } + def isDriver: Boolean = { + executorId == SparkContext.DRIVER_IDENTIFIER || + executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 283090e3bdb1f..6dc5bc4cb08c4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -139,6 +139,12 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } + test("BlockManagerId.isDriver() backwards-compatibility with legacy driver ids (SPARK-6716)") { + assert(BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(BlockManagerId(SparkContext.LEGACY_DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(!BlockManagerId("notADriverIdentifier", "XXX", 1).isDriver) + } + test("master + 1 manager interaction") { store = makeBlockManager(20000) val a1 = new Array[Byte](4000) From 6f0d55d76f758d217fd18ffa0ccf273d7ab0377b Mon Sep 17 00:00:00 2001 From: Matt Aasted Date: Mon, 6 Apr 2015 23:50:48 -0700 Subject: [PATCH 158/171] [SPARK-6636] Use public DNS hostname everywhere in spark_ec2.py The spark_ec2.py script uses public_dns_name everywhere in the script except for testing ssh availability, which is done using the public ip address of the instances. This breaks the script for users who are deploying the cluster with a private-network-only security group. The fix is to use public_dns_name in the remaining place. Author: Matt Aasted Closes #5302 from aasted/master and squashes the following commits: 60cf6ee [Matt Aasted] [SPARK-6636] Use public DNS hostname everywhere in spark_ec2.py --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 5507a9c5a4733..879a52cef8ff0 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -809,7 +809,7 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + if not is_ssh_available(host=i.public_dns_name, opts=opts): return False else: return True From ae980eb41c00b5f1f64c650f267b884e864693f0 Mon Sep 17 00:00:00 2001 From: Sasaki Toru Date: Tue, 7 Apr 2015 01:55:32 -0700 Subject: [PATCH 159/171] [SPARK-6736][GraphX][Doc]Example of Graph#aggregateMessages has error Example of Graph#aggregateMessages has error. 
Since aggregateMessages is a method of Graph, It should be written "rawGraph.aggregateMessages" Author: Sasaki Toru Closes #5388 from sasakitoa/aggregateMessagesExample and squashes the following commits: b1d631b [Sasaki Toru] Example of Graph#aggregateMessages has error --- graphx/src/main/scala/org/apache/spark/graphx/Graph.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 8494d06b1cdb7..36dc7b0f86c89 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -409,7 +409,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") * val inDeg: RDD[(VertexId, Int)] = - * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) * }}} * * @note By expressing computation at the edge level we achieve From b65bad65c3500475b974ca0219f218eef296db2c Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Tue, 7 Apr 2015 08:36:25 -0500 Subject: [PATCH 160/171] [SPARK-3591][YARN]fire and forget for YARN cluster mode https://issues.apache.org/jira/browse/SPARK-3591 The output after this patch: >doggie153:/opt/oss/spark-1.3.0-bin-hadoop2.4/bin # ./spark-submit --class org.apache.spark.examples.SparkPi --master yarn-cluster ../lib/spark-examples*.jar 15/03/31 21:15:25 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 15/03/31 21:15:25 INFO RMProxy: Connecting to ResourceManager at doggie153/10.177.112.153:8032 15/03/31 21:15:25 INFO Client: Requesting a new application from cluster with 4 NodeManagers 15/03/31 21:15:25 INFO Client: Verifying our application has not requested more than the maximum memory capability of the cluster (8192 MB per container) 15/03/31 21:15:25 INFO Client: Will allocate AM container, with 896 MB memory including 384 MB overhead 15/03/31 21:15:25 INFO Client: Setting up container launch context for our AM 15/03/31 21:15:25 INFO Client: Preparing resources for our AM container 15/03/31 21:15:26 INFO Client: Uploading resource file:/opt/oss/spark-1.3.0-bin-hadoop2.4/lib/spark-assembly-1.4.0-SNAPSHOT-hadoop2.4.1.jar -> hdfs://doggie153:9000/user/root/.sparkStaging/application_1427257505534_0016/spark-assembly-1.4.0-SNAPSHOT-hadoop2.4.1.jar 15/03/31 21:15:27 INFO Client: Uploading resource file:/opt/oss/spark-1.3.0-bin-hadoop2.4/lib/spark-examples-1.3.0-hadoop2.4.0.jar -> hdfs://doggie153:9000/user/root/.sparkStaging/application_1427257505534_0016/spark-examples-1.3.0-hadoop2.4.0.jar 15/03/31 21:15:28 INFO Client: Setting up the launch environment for our AM container 15/03/31 21:15:28 INFO SecurityManager: Changing view acls to: root 15/03/31 21:15:28 INFO SecurityManager: Changing modify acls to: root 15/03/31 21:15:28 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(root); users with modify permissions: Set(root) 15/03/31 21:15:28 INFO Client: Submitting application 16 to ResourceManager 15/03/31 21:15:28 INFO YarnClientImpl: Submitted application application_1427257505534_0016 15/03/31 21:15:28 INFO Client: ... waiting before polling ResourceManager for application state 15/03/31 21:15:33 INFO Client: ... 
polling ResourceManager for application state 15/03/31 21:15:33 INFO Client: Application report for application_1427257505534_0016 (state: RUNNING) 15/03/31 21:15:33 INFO Client: client token: N/A diagnostics: N/A ApplicationMaster host: doggie157 ApplicationMaster RPC port: 0 queue: default start time: 1427807728307 final status: UNDEFINED tracking URL: http://doggie153:8088/proxy/application_1427257505534_0016/ user: root /cc andrewor14 Author: WangTaoTheTonic Closes #5297 from WangTaoTheTonic/SPARK-3591 and squashes the following commits: c76d232 [WangTaoTheTonic] wrap lines 16c90a8 [WangTaoTheTonic] move up lines to avoid duplicate fea390d [WangTaoTheTonic] log failed/killed report, style and comment be1cc2e [WangTaoTheTonic] reword f0bc54f [WangTaoTheTonic] minor: expose appid in excepiton messages ba9b22b [WangTaoTheTonic] wrong config name e1a4013 [WangTaoTheTonic] revert to the old version and do some robust 19706c0 [WangTaoTheTonic] add a config to control whether to forget 0cbdce8 [WangTaoTheTonic] fire and forget for YARN cluster mode --- .../org/apache/spark/deploy/Client.scala | 2 +- .../deploy/rest/StandaloneRestClient.scala | 2 +- docs/running-on-yarn.md | 9 ++ .../org/apache/spark/deploy/yarn/Client.scala | 83 +++++++++++-------- 4 files changed, 61 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 65238af2caa24..8d13b2a2cd4f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -89,7 +89,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println(s"... waiting before polling master for driver state") + println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index a3539e44bd2f9..b8fd406fb6f9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -245,7 +245,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } else { val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") - logError("Application submission failed" + failMessage) + logError(s"Application submission failed$failMessage") } } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d9f3eb2b74b18..b7e68d4f71714 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -196,6 +196,15 @@ Most of the configs are the same for Spark on YARN as for other deployment modes It should be no larger than the global number of max attempts in the YARN configuration. + + spark.yarn.submit.waitAppCompletion + true + + In YARN cluster mode, controls whether the client waits to exit until the application completes. + If set to true, the client process will stay alive reporting the application's status. + Otherwise, the client process will exit after submission. 
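For example (a usage sketch only; the example class and jar path mirror the SparkPi invocation shown earlier in this commit message), fire-and-forget submission is requested at submit time:

  ./bin/spark-submit --master yarn-cluster \
    --conf spark.yarn.submit.waitAppCompletion=false \
    --class org.apache.spark.examples.SparkPi lib/spark-examples*.jar

With the flag set to false, the submitting process returns as soon as YARN accepts the application instead of polling for its final state.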
+ + # Launching Spark on YARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 61f8fc3f5a014..79d55a09eb671 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -66,6 +66,8 @@ private[spark] class Client( private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() private val isClusterMode = args.isClusterMode + private val fireAndForget = isClusterMode && + !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) def stop(): Unit = yarnClient.stop() @@ -564,31 +566,13 @@ private[spark] class Client( if (logApplicationReport) { logInfo(s"Application report for $appId (state: $state)") - val details = Seq[(String, String)]( - ("client token", getClientToken(report)), - ("diagnostics", report.getDiagnostics), - ("ApplicationMaster host", report.getHost), - ("ApplicationMaster RPC port", report.getRpcPort.toString), - ("queue", report.getQueue), - ("start time", report.getStartTime.toString), - ("final status", report.getFinalApplicationStatus.toString), - ("tracking URL", report.getTrackingUrl), - ("user", report.getUser) - ) - - // Use more loggable format if value is null or empty - val formattedDetails = details - .map { case (k, v) => - val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") - s"\n\t $k: $newValue" } - .mkString("") // If DEBUG is enabled, log report details every iteration // Otherwise, log them every time the application changes state if (log.isDebugEnabled) { - logDebug(formattedDetails) + logDebug(formatReportDetails(report)) } else if (lastState != state) { - logInfo(formattedDetails) + logInfo(formatReportDetails(report)) } } @@ -609,24 +593,57 @@ private[spark] class Client( throw new SparkException("While loop is depleted! This should never happen...") } + private def formatReportDetails(report: ApplicationReport): String = { + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + details.map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" + }.mkString("") + } + /** - * Submit an application to the ResourceManager and monitor its state. - * This continues until the application has exited for any reason. + * Submit an application to the ResourceManager. + * If set spark.yarn.submit.waitAppCompletion to true, it will stay alive + * reporting the application's status until the application has exited for any reason. + * Otherwise, the client process will exit after submission. * If the application finishes with a failed, killed, or undefined status, * throw an appropriate SparkException. 
*/ def run(): Unit = { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication()) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { - throw new SparkException("Application finished with failed status") - } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { - throw new SparkException("Application is killed") - } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { - throw new SparkException("The final status of application is undefined") + val appId = submitApplication() + if (fireAndForget) { + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + logInfo(s"Application report for $appId (state: $state)") + logInfo(formatReportDetails(report)) + if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + throw new SparkException(s"Application $appId finished with status: $state") + } + } else { + val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) + if (yarnApplicationState == YarnApplicationState.FAILED || + finalApplicationStatus == FinalApplicationStatus.FAILED) { + throw new SparkException(s"Application $appId finished with failed status") + } + if (yarnApplicationState == YarnApplicationState.KILLED || + finalApplicationStatus == FinalApplicationStatus.KILLED) { + throw new SparkException(s"Application $appId is killed") + } + if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + throw new SparkException(s"The final status of application $appId is undefined") + } } } } From 7162ecf88624615c78a332de482f5defd297e415 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 7 Apr 2015 10:42:08 -0700 Subject: [PATCH 161/171] [SPARK-6733][ Scheduler]Added scala.language.existentials Author: Vinod K C Closes #5384 from vinodkc/Suppression_Scala_existential_code and squashes the following commits: 82a3a1f [Vinod K C] Added scala.language.existentials --- .../src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 1 + .../test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 917cce1f9686c..c82ae4baa3630 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.duration._ +import scala.language.existentials import scala.language.postfixOps import scala.util.control.NonFatal diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 29d4ec5f85c1e..fc7349330cf86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.scalatest.FunSuite From 2c32bef1790dac6f77ef9674f6106c2e24ea0338 Mon Sep 17 00:00:00 2001 From: sksamuel Date: Tue, 
7 Apr 2015 10:43:22 -0700 Subject: [PATCH 162/171] Replace use of .size with .length for Arrays Invoking .size on arrays is valid, but requires an implicit conversion to SeqLike. This incurs a compile time overhead and more importantly a runtime overhead, as the Array must be wrapped before the method can be invoked. For example, the difference in generated byte code is: public int withSize(); Code: 0: getstatic #23 // Field scala/Predef$.MODULE$:Lscala/Predef$; 3: aload_0 4: invokevirtual #25 // Method array:()[I 7: invokevirtual #29 // Method scala/Predef$.intArrayOps:([I)Lscala/collection/mutable/ArrayOps; 10: invokeinterface #34, 1 // InterfaceMethod scala/collection/mutable/ArrayOps.size:()I 15: ireturn public int withLength(); Code: 0: aload_0 1: invokevirtual #25 // Method array:()[I 4: arraylength 5: ireturn Author: sksamuel Closes #5376 from sksamuel/master and squashes the following commits: 77ec261 [sksamuel] Replace use of .size with .length for Arrays. --- .../apache/spark/network/nio/Connection.scala | 2 +- .../apache/spark/rdd/AsyncRDDActions.scala | 10 ++++----- .../scala/org/apache/spark/rdd/BlockRDD.scala | 2 +- .../org/apache/spark/rdd/CartesianRDD.scala | 4 ++-- .../org/apache/spark/rdd/CheckpointRDD.scala | 2 +- .../org/apache/spark/rdd/CoGroupedRDD.scala | 4 ++-- .../org/apache/spark/rdd/CoalescedRDD.scala | 2 +- .../apache/spark/rdd/DoubleRDDFunctions.scala | 4 ++-- .../spark/rdd/OrderedRDDFunctions.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 22 +++++++++---------- .../apache/spark/rdd/RDDCheckpointData.scala | 6 ++--- .../org/apache/spark/rdd/SubtractedRDD.scala | 2 +- .../scala/org/apache/spark/rdd/UnionRDD.scala | 6 ++--- .../spark/rdd/ZippedPartitionsRDD.scala | 4 ++-- .../apache/spark/rdd/ZippedWithIndexRDD.scala | 2 +- .../org/apache/spark/storage/RDDInfo.scala | 2 +- .../apache/spark/ui/ConsoleProgressBar.scala | 4 ++-- .../apache/spark/util/collection/BitSet.scala | 2 +- 19 files changed, 42 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 04eb2bf9ba4ab..6b898bd4bfc1b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -181,7 +181,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, buffer.get(bytes) bytes.foreach(x => print(x + " ")) buffer.position(curPosition) - print(" (" + bytes.size + ")") + print(" (" + bytes.length + ")") } def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 646df283ac069..3406a7e97e368 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -45,7 +45,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } result }, - Range(0, self.partitions.size), + Range(0, self.partitions.length), (index: Int, data: Long) => totalCount.addAndGet(data), totalCount.get()) } @@ -54,8 +54,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving all elements of this RDD. 
*/ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + val results = new Array[Array[T]](self.partitions.length) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) } @@ -111,7 +111,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def foreachAsync(f: T => Unit): FutureAction[Unit] = { val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) } @@ -119,7 +119,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Applies a function f to each partition of this RDD. */ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fffa1911f5bc2..71578d1210fde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -36,7 +36,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.size).map(i => { + (0 until blockIds.length).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] }).toArray } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9059eb13bb5d8..c1d6971787572 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -53,11 +53,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numPartitionsInRdd2 = rdd2.partitions.size + val numPartitionsInRdd2 = rdd2.partitions.length override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length) for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { val idx = s1.index * numPartitionsInRdd2 + s2.index array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 760c0fa3ac96a..0d130dd4c7a60 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -49,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.size + val numPart = partitionFiles.length if (numPart > 0 && (! 
partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 07398a6fa62f6..7021a339e879b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -99,7 +99,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will have a dependency per contributing RDD array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it @@ -120,7 +120,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size + val numRdds = split.deps.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 5117ccfabfcc2..0c1b02c07d09f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -166,7 +166,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // determines the tradeoff between load-balancing the partitions sizes and their locality // e.g. 
balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt + val slack = (balanceSlack * prev.partitions.length).toInt var noLocality = true // if true if no preferredLocations exists for parent RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 71e6e300fec5f..29ca3e9c4bd04 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -70,7 +70,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -81,7 +81,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) + val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6fdfdb734d1b8..6afe50161dacd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -56,7 +56,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * order of the keys). */ // TODO: this currently doesn't work on P other than Tuple2! - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) : RDD[(K, V)] = { val part = new RangePartitioner(numPartitions, self, ascending) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index bf1303d39592d..05351ba4ff76b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -823,7 +823,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * RDD will be <= us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. 
*/ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ddbfd5624e741..d80d94a588346 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -316,7 +316,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.size) + def distinct(): RDD[T] = distinct(partitions.length) /** * Return a new RDD that has exactly numPartitions partitions. @@ -488,7 +488,7 @@ abstract class RDD[T: ClassTag]( def sortBy[K]( f: (T) => K, ascending: Boolean = true, - numPartitions: Int = this.partitions.size) + numPartitions: Int = this.partitions.length) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = this.keyBy[K](f) .sortByKey(ascending, numPartitions) @@ -852,7 +852,7 @@ abstract class RDD[T: ClassTag]( * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) /** * Return an RDD with the elements from `this` that are not in `other`. @@ -986,14 +986,14 @@ abstract class RDD[T: ClassTag]( combOp: (U, U) => U, depth: Int = 2): U = { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (partitions.size == 0) { + if (partitions.length == 0) { return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) } val cleanSeqOp = context.clean(seqOp) val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size + var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. while (numPartitions > scale + numPartitions / scale) { @@ -1026,7 +1026,7 @@ abstract class RDD[T: ClassTag]( } result } - val evaluator = new CountEvaluator(partitions.size, confidence) + val evaluator = new CountEvaluator(partitions.length, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -1061,7 +1061,7 @@ abstract class RDD[T: ClassTag]( } map } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.length, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -1140,7 +1140,7 @@ abstract class RDD[T: ClassTag]( * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. 
*/ def zipWithUniqueId(): RDD[(T, Long)] = { - val n = this.partitions.size.toLong + val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => (item, i * n + k) @@ -1243,7 +1243,7 @@ abstract class RDD[T: ClassTag]( queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) } - if (mapRDDs.partitions.size == 0) { + if (mapRDDs.partitions.length == 0) { Array.empty } else { mapRDDs.reduce { (queue1, queue2) => @@ -1489,7 +1489,7 @@ abstract class RDD[T: ClassTag]( } // The first RDD in the dependency stack has no parents, so no need for a +- def firstDebugString(rdd: RDD[_]): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) @@ -1499,7 +1499,7 @@ abstract class RDD[T: ClassTag]( } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val thisPrefix = prefix.replaceAll("\\|\\s+$", "") val nextPrefix = ( diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index f67e5f1857979..6afd63d537d75 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -94,10 +94,10 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) new SerializableWritable(rdd.context.hadoopConfiguration)) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (newRDD.partitions.size != rdd.partitions.size) { + if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") } // Change the dependencies and partitions of the RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index c27f435eb9c5a..e9d745588ee9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -76,7 +76,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 4239e7e22af89..3986645350a82 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala 
@@ -63,7 +63,7 @@ class UnionRDD[T: ClassTag]( extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) + val array = new Array[Partition](rdds.map(_.partitions.length).sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) @@ -76,8 +76,8 @@ class UnionRDD[T: ClassTag]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) + pos += rdd.partitions.length } deps } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index d0be304762e1f..a96b6c3d23454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -52,8 +52,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val numParts = rdds.head.partitions.size - if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { + val numParts = rdds.head.partitions.length + if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } Array.tabulate[Partition](numParts) { i => diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 8c43a559409f2..523aaf2b860b5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -41,7 +41,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L /** The start index of each partition. 
*/ @transient private val startIndices: Array[Long] = { - val n = prev.partitions.size + val n = prev.partitions.length if (n == 0) { Array[Long]() } else if (n == 1) { diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 0186eb30a1905..034525b56f59c 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -52,6 +52,6 @@ class RDDInfo( private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 67f572e79314d..77c0bc8b5360a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -65,7 +65,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val stageIds = sc.statusTracker.getActiveStageIds() val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) - if (stages.size > 0) { + if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } } @@ -81,7 +81,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val total = s.numTasks() val header = s"[Stage ${s.stageId()}:" val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" - val w = width - header.size - tailer.size + val w = width - header.length - tailer.length val bar = if (w > 0) { val percent = w * s.numCompletedTasks() / total (0 until w).map { i => diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index f79e8e0491ea1..41cb8cfe2afa3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -39,7 +39,7 @@ class BitSet(numBits: Int) extends Serializable { val wordIndex = bitIndex >> 6 // divide by 64 var i = 0 while(i < wordIndex) { words(i) = -1; i += 1 } - if(wordIndex < words.size) { + if(wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) words(wordIndex) |= mask From 12322159147581602978f7f5a6b33b887ef781a1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 Apr 2015 12:37:33 -0700 Subject: [PATCH 163/171] [SPARK-6750] Upgrade ScalaStyle to 0.7. 0.7 fixes a bug that's pretty useful, i.e. inline functions no longer return explicit type definition. Author: Reynold Xin Closes #5399 from rxin/style0.7 and squashes the following commits: 54c41b2 [Reynold Xin] Actually update the version. 09c759c [Reynold Xin] [SPARK-6750] Upgrade ScalaStyle to 0.7. 
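To make the Array `.size` versus `.length` point from [PATCH 162/171] above concrete, here is a minimal, illustrative-only sketch (the names are hypothetical): `.length` compiles to the JVM `arraylength` instruction, while `.size` first wraps the array in an `ArrayOps`, exactly as the bytecode in that commit message shows.

```scala
// Illustrative sketch of the .size vs .length distinction for Arrays.
object SizeVsLength {
  val partitions: Array[Int] = Array(1, 2, 3)

  // Goes through the implicit Predef.intArrayOps conversion, allocating a wrapper per call.
  def withSize: Int = partitions.size

  // Reads the array length directly via the `arraylength` bytecode instruction.
  def withLength: Int = partitions.length

  def main(args: Array[String]): Unit = {
    assert(withSize == withLength) // both return 3; only the cost differs
  }
}
```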
--- project/plugins.sbt | 2 +- project/project/SparkPluginBuild.scala | 16 +------- .../scalastyle/NonASCIICharacterChecker.scala | 39 ------------------- 3 files changed, 2 insertions(+), 55 deletions(-) delete mode 100644 project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala diff --git a/project/plugins.sbt b/project/plugins.sbt index ee45b6a51905e..7096b0d3ee7de 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 8863f272da415..471d00bd8223f 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,20 +24,6 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) - lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") - - // There is actually no need to publish this artifact. - def styleSettings = Defaults.defaultSettings ++ Seq ( - name := "spark-style", - organization := "org.apache.spark", - scalaVersion := "2.10.4", - scalacOptions := Seq("-unchecked", "-deprecation"), - libraryDependencies ++= Dependencies.scalaStyle - ) - - object Dependencies { - val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") - } } diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala deleted file mode 100644 index 3d43c35299555..0000000000000 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -package org.apache.spark.scalastyle - -import java.util.regex.Pattern - -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} - -import scalariform.lexer.Token -import scalariform.parser.CompilationUnit - -class NonASCIICharacterChecker extends ScalariformChecker { - val errorKey: String = "non.ascii.character.disallowed" - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList - } - - private def hasNonAsciiChars(x: Token) = - x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) - .matcher(x.text.trim).matches() - -} From 596ba77c5fdca79486396989e549632153055caf Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 7 Apr 2015 14:29:53 -0700 Subject: [PATCH 164/171] [SPARK-6568] spark-shell.cmd --jars option does not accept the jar that has space in its path escape spaces in the arguments. Author: Masayoshi TSUZUKI Closes #5347 from tsudukim/feature/SPARK-6568 and squashes the following commits: 9180aaf [Masayoshi TSUZUKI] [SPARK-6568] spark-shell.cmd --jars option does not accept the jar that has space in its path --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- core/src/test/scala/org/apache/spark/util/UtilsSuite.scala | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 0fdfaf300e95d..25ae6ee579ab3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1661,7 +1661,7 @@ private[spark] object Utils extends Logging { /** * Format a Windows path such that it can be safely passed to a URI. */ - def formatWindowsPath(path: String): String = path.replace("\\", "/") + def formatWindowsPath(path: String): String = path.replace("\\", "/").replace(" ", "%20") /** * Indicates whether Spark is currently running unit tests. 
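As a standalone sketch (the real `formatWindowsPath` helper above is `private[spark]`), the escaping this patch adds can be reproduced with two `replace` calls before handing the path to `java.net.URI`:

```scala
import java.net.URI

// Mirrors the formatWindowsPath change above: forward slashes plus %20-escaped spaces.
def toUriCompatible(windowsPath: String): String =
  windowsPath.replace("\\", "/").replace(" ", "%20")

// "C:\path to\jar4.jar" becomes "C:/path%20to/jar4.jar", which URI parses cleanly;
// the unescaped form would throw a URISyntaxException.
val uri = new URI("file:/" + toUriCompatible("""C:\path to\jar4.jar"""))
```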
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 5d93086082189..b7cc84078983a 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -241,6 +241,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assertResolves("C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) assertResolves("C:\\path\\to\\file.txt", "file:/C:/path/to/file.txt", testWindows = true) assertResolves("file:/C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) + assertResolves("file:/C:/path to/file.txt", "file:/C:/path%20to/file.txt", testWindows = true) assertResolves("file:///C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) assertResolves("file:/C:/file.txt#alias.txt", "file:/C:/file.txt#alias.txt", testWindows = true) intercept[IllegalArgumentException] { Utils.resolveURI("file:foo") } @@ -264,8 +265,9 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assertResolves("hdfs:/jar1,file:/jar2,jar3", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3") assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4#jar5") - assertResolves("hdfs:/jar1,file:/jar2,jar3,C:\\pi.py#py.pi", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi", testWindows = true) + assertResolves("""hdfs:/jar1,file:/jar2,jar3,C:\pi.py#py.pi,C:\path to\jar4.jar""", + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi,file:/C:/path%20to/jar4.jar", + testWindows = true) } test("nonLocalPaths") { From e6f08fb42fda35952ea8b005170750ae551dc7d9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 7 Apr 2015 14:34:15 -0700 Subject: [PATCH 165/171] Revert "[SPARK-6568] spark-shell.cmd --jars option does not accept the jar that has space in its path" This reverts commit 596ba77c5fdca79486396989e549632153055caf. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- core/src/test/scala/org/apache/spark/util/UtilsSuite.scala | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 25ae6ee579ab3..0fdfaf300e95d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1661,7 +1661,7 @@ private[spark] object Utils extends Logging { /** * Format a Windows path such that it can be safely passed to a URI. */ - def formatWindowsPath(path: String): String = path.replace("\\", "/").replace(" ", "%20") + def formatWindowsPath(path: String): String = path.replace("\\", "/") /** * Indicates whether Spark is currently running unit tests. 
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index b7cc84078983a..5d93086082189 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -241,7 +241,6 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assertResolves("C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) assertResolves("C:\\path\\to\\file.txt", "file:/C:/path/to/file.txt", testWindows = true) assertResolves("file:/C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) - assertResolves("file:/C:/path to/file.txt", "file:/C:/path%20to/file.txt", testWindows = true) assertResolves("file:///C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) assertResolves("file:/C:/file.txt#alias.txt", "file:/C:/file.txt#alias.txt", testWindows = true) intercept[IllegalArgumentException] { Utils.resolveURI("file:foo") } @@ -265,9 +264,8 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assertResolves("hdfs:/jar1,file:/jar2,jar3", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3") assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4#jar5") - assertResolves("""hdfs:/jar1,file:/jar2,jar3,C:\pi.py#py.pi,C:\path to\jar4.jar""", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi,file:/C:/path%20to/jar4.jar", - testWindows = true) + assertResolves("hdfs:/jar1,file:/jar2,jar3,C:\\pi.py#py.pi", + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi", testWindows = true) } test("nonLocalPaths") { From fc957dc78138e72036dbbadc9a54f155d318c038 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 7 Apr 2015 14:36:57 -0700 Subject: [PATCH 166/171] [SPARK-6720][MLLIB] PySpark MultivariateStatisticalSummary unit test for normL1... ... and normL2. Add test cases to insufficient unit test for `normL1` and `normL2`. Ref: https://github.com/apache/spark/pull/5359 Author: lewuathe Closes #5374 from Lewuathe/SPARK-6720 and squashes the following commits: 5541b24 [lewuathe] More accurate tests dc5718c [lewuathe] [SPARK-6720] PySpark MultivariateStatisticalSummary unit test for normL1 and normL2 --- python/pyspark/mllib/tests.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 47dad7d12e4e4..61ef398487c0c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -363,6 +363,13 @@ def test_col_norms(self): self.assertEqual(10, len(summary.normL1())) self.assertEqual(10, len(summary.normL2())) + data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + class VectorUDTTests(PySparkTestCase): From 77bcceb9f01e97cb6f41791f2167b40c4311f701 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Apr 2015 07:00:56 +0800 Subject: [PATCH 167/171] [SPARK-6748] [SQL] Makes QueryPlan.schema a lazy val `DataFrame.collect()` calls `SparkPlan.executeCollect()`, which consists of a single line: ```scala execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() ``` The problem is that, `QueryPlan.schema` is a function. And since 1.3.0, `convertRowToScala` starts returning a `GenericRowWithSchema`. 
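A generic, illustrative-only sketch of the `def` versus `lazy val` difference at issue here (not the actual `QueryPlan` code): a `def` rebuilds its result on every call, whereas a `lazy val` builds it once and caches it.

```scala
// With a def, each call allocates a fresh result; with a lazy val, the first call is cached.
class WithDef(names: Seq[String])  { def schema: Vector[String] = names.toVector }
class WithLazy(names: Seq[String]) { lazy val schema: Vector[String] = names.toVector }

val d = new WithDef(Seq("key", "value"))
val l = new WithLazy(Seq("key", "value"))
assert(!(d.schema eq d.schema)) // two calls, two distinct allocations
assert(l.schema eq l.schema)    // one shared, cached instance
```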
Thus, every `GenericRowWithSchema` instance holds a separate copy of the schema object. Also, YJP profiling result of the following simple micro benchmark (executed in Spark shell) shows that constructing the schema object takes up to ~35% CPU time. ```scala sc.parallelize(1 to 10000000). map(i => (i, s"val_$i")). toDF("key", "value"). saveAsParquetFile("file:///tmp/src.parquet") // Profiling started from this line sqlContext.parquetFile("file:///tmp/src.parquet").collect() ``` [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5398) Author: Cheng Lian Closes #5398 from liancheng/spark-6748 and squashes the following commits: 3159469 [Cheng Lian] Makes QueryPlan.schema a lazy val --- .../scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 02f7c26a8ab6e..7967189cacb24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -150,7 +150,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - def schema: StructType = StructType.fromAttributes(output) + lazy val schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString From c83e03948b184ffb3a9418fecc4d2c26ae33b057 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Apr 2015 16:18:55 -0700 Subject: [PATCH 168/171] [SPARK-6737] Fix memory leak in OutputCommitCoordinator This patch fixes a memory leak in the DAGScheduler, which caused us to leak a map entry per submitted stage. The problem is that the OutputCommitCoordinator needs to be informed when stages end in order to remove entries from its `authorizedCommitters` map, but the DAGScheduler only called it in one of the four code paths that are used to mark stages as completed. This patch fixes this issue by consolidating the processing of stage completion into a new `markStageAsFinished` method and updates DAGSchedulerSuite's `assertDataStructuresEmpty` assertion to also check the OutputCommitCoordinator data structures. I've also added a comment at the top of DAGScheduler so that we remember to update this test when adding new data structures. Author: Josh Rosen Closes #5397 from JoshRosen/SPARK-6737 and squashes the following commits: af3b02f [Josh Rosen] Consolidate stage completion handling code in a single method. e96ce3a [Josh Rosen] Consolidate stage completion handling code in a single method. 3052aea [Josh Rosen] Comment update 7896899 [Josh Rosen] Fix SPARK-6737 by informing OutputCommitCoordinator of all stage end events. 
4ead1dc [Josh Rosen] Add regression tests for SPARK-6737 --- .../apache/spark/scheduler/DAGScheduler.scala | 63 ++++++++++--------- .../scheduler/OutputCommitCoordinator.scala | 7 +++ .../spark/scheduler/DAGSchedulerSuite.scala | 1 + 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c82ae4baa3630..c912520fded3b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -50,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -111,6 +115,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() @@ -128,8 +134,6 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - private val outputCommitCoordinator = env.outputCommitCoordinator - // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -710,9 +714,10 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } @@ -887,10 +892,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. 
- outputCommitCoordinator.stageEnd(stage.id) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) val debugString = stage match { case stage: ShuffleMapStage => @@ -902,7 +906,6 @@ class DAGScheduler( s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" } logDebug(debugString) - runningStages -= stage } } @@ -968,22 +971,6 @@ class DAGScheduler( } val stage = stageIdToStage(task.stageId) - - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTimeMillis()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -1099,7 +1086,6 @@ class DAGScheduler( logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage } if (disallowStageRetryForTest) { @@ -1215,6 +1201,26 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. 
@@ -1264,8 +1270,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 9e29fd13821dc..7c184b1dcb308 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -59,6 +59,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + /** * Called by tasks to ask whether they can commit their output to HDFS. * diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 63360a0f189a3..eb759f0807a17 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -783,6 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) + assert(scheduler.outputCommitCoordinator.isEmpty) } // Nothing in this test should break if the task info's fields are null, but From d138aa8ee23f4450242da3ac70a493229a90c76b Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Tue, 7 Apr 2015 23:36:31 -0400 Subject: [PATCH 169/171] [SPARK-6705][MLLIB] Add fit intercept api to ml logisticregression I have the fit intercept enabled by default for logistic regression, I wonder what others think here. I understand that it enables allocation by default which is undesirable, but one needs to have a very strong reason for not having an intercept term enabled so it is the safer default from a statistical sense. Explicitly modeling the intercept by adding a column of all 1s does not work. I believe the reason is that since the API for LogisticRegressionWithLBFGS forces column normalization, and a column of all 1s has 0 variance so dividing by 0 kills it. 
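Concretely, the new param is used the way the test added below exercises it. A minimal sketch, assuming `training` is a DataFrame with the usual "label" and "features" columns:

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.sql.DataFrame

// Sketch only: fit a model with the intercept term disabled (fitIntercept defaults to true).
def fitWithoutIntercept(training: DataFrame): LogisticRegressionModel = {
  val lr = new LogisticRegression()
    .setMaxIter(10)
    .setFitIntercept(false)
  val model = lr.fit(training)
  assert(model.intercept == 0.0) // no intercept term was fitted
  model
}
```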
Author: Omede Firouz Closes #5301 from oefirouz/addIntercept and squashes the following commits: 9f1286b [Omede Firouz] [SPARK-6705][MLLIB] Add fitInterceptTerm to LogisticRegression 1d6bd6f [Omede Firouz] [SPARK-6705][MLLIB] Add a fit intercept term to ML LogisticRegression 9963509 [Omede Firouz] [MLLIB] Add fitIntercept to LogisticRegression 2257fca [Omede Firouz] [MLLIB] Add fitIntercept param to logistic regression 329c1e2 [Omede Firouz] [MLLIB] Add fit intercept term bd9663c [Omede Firouz] [MLLIB] Add fit intercept api to ml logisticregression --- .../spark/ml/classification/LogisticRegression.scala | 8 ++++++-- .../org/apache/spark/ml/param/sharedParams.scala | 12 ++++++++++++ .../ml/classification/LogisticRegressionSuite.scala | 9 +++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 49c00f77480e8..34625745dd0a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold /** @@ -55,6 +55,9 @@ class LogisticRegression /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + /** @group setParam */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -67,7 +70,8 @@ class LogisticRegression } // Train model - val lr = new LogisticRegressionWithLBFGS + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(paramMap(fitIntercept)) lr.optimizer .setRegParam(paramMap(regParam)) .setNumIterations(paramMap(maxIter)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index 5d660d1e151a7..0739fdbfcbaae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params { def getProbabilityCol: String = get(probabilityCol) } +private[ml] trait HasFitIntercept extends Params { + /** + * param for fitting the intercept term, defaults to true + * @group param + */ + val fitIntercept: BooleanParam = + new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true)) + + /** @group getParam */ + def getFitIntercept: Boolean = get(fitIntercept) +} + private[ml] trait HasThreshold extends Params { /** * param for threshold in (binary) prediction diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b3d1bfcfbee0f..35d8c2e16c6cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { 
assert(lr.getPredictionCol == "prediction") assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") @@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") + assert(model.intercept !== 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) } test("logistic regression with setters") { From 8d2a36c0fdfbea9f58271ef6aeb89bb79b22cf62 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 7 Apr 2015 22:40:42 -0700 Subject: [PATCH 170/171] [SPARK-6754] Remove unnecessary TaskContextHelper The TaskContextHelper was originally necessary because TaskContext was written in Java, which does not have a way to specify that classes are package-private, so TaskContextHelper existed to work around this. Now that TaskContext has been re-written in Scala, this class is no longer necessary. rxin can you look at this? It looks like you missed this bit of cleanup when you moved TaskContext from Java to Scala in #4324 cc ScrapCodes and pwendell who added this originally. Author: Kay Ousterhout Closes #5402 from kayousterhout/SPARK-6754 and squashes the following commits: f089800 [Kay Ousterhout] [SPARK-6754] Remove unnecessary TaskContextHelper --- .../org/apache/spark/TaskContextHelper.scala | 29 ------------------- .../apache/spark/scheduler/DAGScheduler.scala | 4 +-- .../org/apache/spark/scheduler/Task.scala | 6 ++-- 3 files changed, 5 insertions(+), 34 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/TaskContextHelper.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala deleted file mode 100644 index 4636c4600a01a..0000000000000 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -/** - * This class exists to restrict the visibility of TaskContext setters. 
- */ -private [spark] object TaskContextHelper { - - def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) - - def unset(): Unit = TaskContext.unset() - -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c912520fded3b..508fe7b3303ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -645,13 +645,13 @@ class DAGScheduler( val split = rdd.partitions(job.partitions(0)) val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, attemptNumber = 0, runningLocally = true) - TaskContextHelper.setTaskContext(taskContext) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4d9f940813b8e..8b592867ee31d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(taskAttemptId: Long, attemptNumber: Int): T = { context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) - TaskContextHelper.setTaskContext(context) + TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() if (_killed) { @@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } From 15e0d2bd1304db62fad286c1bb687e87c361e16c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Apr 2015 00:24:59 -0700 Subject: [PATCH 171/171] [SPARK-6765] Fix test code style for streaming. So we can turn style checker on for test code. Author: Reynold Xin Closes #5409 from rxin/test-style-streaming and squashes the following commits: 7aea69b [Reynold Xin] [SPARK-6765] Fix test code style for streaming. 
--- .../flume/FlumePollingStreamSuite.scala | 29 ++++++------ .../streaming/flume/FlumeStreamSuite.scala | 4 +- .../streaming/mqtt/MQTTStreamSuite.scala | 3 +- .../streaming/BasicOperationsSuite.scala | 6 ++- .../spark/streaming/CheckpointSuite.scala | 45 ++++++++++++++----- .../apache/spark/streaming/FailureSuite.scala | 4 +- .../spark/streaming/InputStreamsSuite.scala | 15 ++++--- .../streaming/ReceivedBlockHandlerSuite.scala | 4 +- .../streaming/ReceivedBlockTrackerSuite.scala | 6 ++- .../spark/streaming/ReceiverSuite.scala | 11 ++--- .../streaming/StreamingContextSuite.scala | 5 ++- .../streaming/StreamingListenerSuite.scala | 4 +- .../spark/streaming/TestSuiteBase.scala | 28 +++++++----- .../spark/streaming/UISeleniumSuite.scala | 3 +- .../streaming/WindowOperationsSuite.scala | 4 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 12 +++-- .../scheduler/JobGeneratorSuite.scala | 2 +- .../streaming/util/WriteAheadLogSuite.scala | 2 +- .../spark/streamingtest/ImplicitSuite.scala | 3 +- 19 files changed, 115 insertions(+), 75 deletions(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index e04d4088df7dc..2edea9b5b69ba 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -1,21 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + package org.apache.spark.streaming.flume import java.net.InetSocketAddress @@ -213,7 +212,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging assert(counter === totalEventsPerChannel * channels.size) } - def assertChannelIsEmpty(channel: MemoryChannel) = { + def assertChannelIsEmpty(channel: MemoryChannel): Unit = { val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 51d273af8da84..39e6754c81dbf 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -151,7 +151,9 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L } /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 24d78ecb3a97d..a19a72c58a705 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -139,7 +139,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { msgTopic.publish(message) } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(50) // wait for Spark streaming to consume something from the message queue + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index cf191715d29d6..87bc20f79c3cd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -171,7 +171,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("flatMapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + (s: DStream[String]) => { + s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)) + }, Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), true ) @@ -474,7 +476,7 @@ class BasicOperationsSuite extends TestSuiteBase { stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = { stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet } diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 91a2b2bba461d..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -43,7 +43,7 @@ class CheckpointSuite extends TestSuiteBase { var ssc: StreamingContext = null - override def batchDuration = Milliseconds(500) + override def batchDuration: Duration = Milliseconds(500) override def beforeFunction() { super.beforeFunction() @@ -72,7 +72,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.sum + state.getOrElse(0))) + Some(values.sum + state.getOrElse(0)) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -199,7 +199,12 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), Seq() ), 3 ) } @@ -212,7 +217,8 @@ class CheckpointSuite extends TestSuiteBase { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val output = Seq( + Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) @@ -236,7 +242,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[TextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -259,7 +271,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[NewTextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -298,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase { output } }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -533,7 +557,8 @@ class CheckpointSuite extends TestSuiteBase { * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. 
*/ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { @@ -543,7 +568,7 @@ class CheckpointSuite extends TestSuiteBase { logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.filter { dstream => + val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) @@ -552,4 +577,4 @@ class CheckpointSuite extends TestSuiteBase { private object CheckpointSuite extends Serializable { var batchThreeShouldBlockIndefinitely: Boolean = true -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 26435d8515815..0c4c06534a693 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -29,9 +29,9 @@ class FailureSuite extends TestSuiteBase with Logging { val directory = Utils.createTempDir() val numBatches = 30 - override def batchDuration = Milliseconds(1000) + override def batchDuration: Duration = Milliseconds(1000) - override def useManualClock = false + override def useManualClock: Boolean = false override def afterFunction() { Utils.deleteRecursively(directory) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 7ed6320a3d0bc..e6ac4975c5e68 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -52,7 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -164,7 +164,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] val outputStream = new TestOutputStream(countStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -196,7 +196,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -204,7 +204,7 @@ class InputStreamsSuite extends TestSuiteBase with 
BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - //Thread.sleep(1000) + val inputIterator = input.toIterator for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time @@ -239,7 +239,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -352,7 +352,8 @@ class TestServer(portToBind: Int = 0) extends Logging { logInfo("New connection") try { clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) while(clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) @@ -384,7 +385,7 @@ class TestServer(portToBind: Int = 0) extends Logging { def stop() { servingThread.interrupt() } - def port = serverSocket.getLocalPort + def port: Int = serverSocket.getLocalPort } /** This is a receiver to test multiple threads inserting data using block generator */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ef4873de2f5a9..c090eaec2928d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -96,7 +96,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -120,7 +120,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 42fad769f0c1a..b63b37d9f9cef 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -228,7 +228,8 @@ class ReceivedBlockTrackerSuite * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. 
*/ - def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) + : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new WriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => @@ -244,7 +245,8 @@ class ReceivedBlockTrackerSuite } /** Create batch allocation object from the given info */ - def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]) + : BatchAllocationEvent = { BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index aa20ad0b5374e..10c35cba8dc53 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -308,7 +308,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { val errors = new ArrayBuffer[Throwable] /** Check if all data structures are clean */ - def isAllEmpty = { + def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && arrayBuffers.isEmpty && errors.isEmpty } @@ -320,24 +320,21 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { byteBuffers += bytes } def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { iterators += iterator } def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { arrayBuffers += arrayBuffer } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 2e5005ef6ff14..d1bbf39dc7897 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -213,7 +213,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(sc, Milliseconds(100)) var runningCount = 0 SlowTestReceiver.receivedAllRecords = false - //Create test receiver that sleeps in onStop() + // Create test receiver that sleeps in onStop() val totalNumRecords = 15 val recordsPerSecond = 1 val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) @@ -370,7 +370,8 @@ object TestReceiver { } /** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ -class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { var receivingThreadOption: Option[Thread] = None diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 
f52562b0a0f73..852e8bb71d4f6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -38,8 +38,8 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { // To make sure that the processing start and end times in collected // information are different for successive batches - override def batchDuration = Milliseconds(100) - override def actuallyWait = true + override def batchDuration: Duration = Milliseconds(100) + override def actuallyWait: Boolean = true test("batch info reporting") { val ssc = setupStreams(input, operation) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 3565d621e8a6c..c3cae8aeb6d15 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -53,8 +53,9 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], val selectedInput = if (index < input.size) input(index) else Seq[T]() // lets us test cases where RDDs are not created - if (selectedInput == null) + if (selectedInput == null) { return None + } val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) @@ -104,7 +105,9 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], output.clear() } - def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + def toTestOutputStream: TestOutputStream[T] = { + new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + } } /** @@ -148,34 +151,34 @@ class BatchCounter(ssc: StreamingContext) { trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = this.getClass.getSimpleName + def framework: String = this.getClass.getSimpleName // Master for Spark context - def master = "local[2]" + def master: String = "local[2]" // Batch duration - def batchDuration = Seconds(1) + def batchDuration: Duration = Seconds(1) // Directory where the checkpoint data will be saved - lazy val checkpointDir = { + lazy val checkpointDir: String = { val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") dir.toString } // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 + def numInputPartitions: Int = 2 // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 + def maxWaitTimeMillis: Int = 10000 // Whether to use manual clock or not - def useManualClock = true + def useManualClock: Boolean = true // Whether to actually wait in real time before changing manual clock - def actuallyWait = false + def actuallyWait: Boolean = false - //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. 
val conf = new SparkConf() .setMaster(master) .setAppName(framework) @@ -346,7 +349,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + while (output.size < numExpectedOutput && + System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) ssc.awaitTerminationOrTimeout(50) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 87a0395efbf2a..998426ebb82e5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark._ /** * Selenium tests for the Spark Web UI. */ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { +class UISeleniumSuite + extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index a5d2bb2fde16c..c39ad05f41520 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { - override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer + override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer - override def batchDuration = Seconds(1) // making sure its visible in this class + override def batchDuration: Duration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 7a6a2f3e577dd..c3602a5b73732 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -28,10 +28,13 @@ import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBloc import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} import org.apache.spark.util.Utils -class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { +class WriteAheadLogBackedBlockRDDSuite + extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() var sparkContext: SparkContext = null @@ -86,7 +89,8 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log * @param testStoreInBM Test whether blocks read from log are stored back into 
block manager */ - private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + private def testRDD( + numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { val numBlocks = numPartitionsInBM + numPartitionsInWAL val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) @@ -110,7 +114,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w "Unexpected blocks in BlockManager" ) - // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( segments.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), @@ -152,6 +156,6 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w } private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 4150b60635ed6..7865b06c2e3c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -90,7 +90,7 @@ class JobGeneratorSuite extends TestSuiteBase { val receiverTracker = ssc.scheduler.receiverTracker // Get the blocks belonging to a batch - def getBlocksOfBatch(batchTime: Long) = { + def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = { receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 8335659667f22..a3919c43b95b4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -291,7 +291,7 @@ object WriteAheadLogSuite { manager } - /** Read data from a segments of a log file directly and return the list of byte buffers.*/ + /** Read data from a segments of a log file directly and return the list of byte buffers. */ def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala index d0bf328f2b74d..d66750463033a 100644 --- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -25,7 +25,8 @@ package org.apache.spark.streamingtest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real DStream. + // We only want to test if `implicit` works well with the compiler, + // so we don't need a real DStream. def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null def testToPairDStreamFunctions(): Unit = {