diff --git a/.gitignore b/.gitignore index a4ec12ca6b53f..7ec8d45e12c6b 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,4 @@ metastore_db/ metastore/ warehouse/ TempStatsStore/ +sql/hive-thriftserver/test_warehouses diff --git a/LICENSE b/LICENSE index 65e1f480d9b14..76a3601c66918 100644 --- a/LICENSE +++ b/LICENSE @@ -272,7 +272,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ======================================================================== -For Py4J (python/lib/py4j0.7.egg and files in assembly/lib/net/sf/py4j): +For Py4J (python/lib/py4j-0.8.2.1-src.zip) ======================================================================== Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. @@ -532,7 +532,7 @@ The following components are provided under a BSD-style license. See project lin (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.1 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.2.1 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (ISC/BSD License) jbcrypt (org.mindrot:jbcrypt:0.3m - http://www.mindrot.org/) diff --git a/assembly/pom.xml b/assembly/pom.xml index 567a8dd2a0d94..703f15925bc44 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -165,6 +165,16 @@ + + hive-thriftserver + + + org.apache.spark + spark-hive-thriftserver_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl diff --git a/bagel/pom.xml b/bagel/pom.xml index 90c4b095bb611..bd51b112e26fa 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-bagel_2.10 - bagel + bagel jar Spark Project Bagel diff --git a/bin/beeline b/bin/beeline new file mode 100755 index 0000000000000..09fe366c609fa --- /dev/null +++ b/bin/beeline @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +# Find the java binary +if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" +else + if [ `command -v java` ]; then + RUNNER="java" + else + echo "JAVA_HOME is not set" >&2 + exit 1 + fi +fi + +# Compute classpath using external script +classpath_output=$($FWDIR/bin/compute-classpath.sh) +if [[ "$?" 
!= "0" ]]; then + echo "$classpath_output" + exit 1 +else + CLASSPATH=$classpath_output +fi + +CLASS="org.apache.hive.beeline.BeeLine" +exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index e81e8c060cb98..16b794a1592e8 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -52,6 +52,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" fi diff --git a/bin/pyspark b/bin/pyspark index 69b056fe28f2c..39a20e2a24a3c 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -52,7 +52,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH +export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP=$PYTHONSTARTUP diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 0ef9eea95342e..2c4b08af8d4c3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -45,7 +45,7 @@ rem Figure out which Python to use. if [%PYSPARK_PYTHON%] == [] set PYSPARK_PYTHON=python set PYTHONPATH=%FWDIR%python;%PYTHONPATH% -set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%FWDIR%python\pyspark\shell.py diff --git a/bin/spark-shell b/bin/spark-shell index 850e9507ec38f..756c8179d12b6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -46,11 +46,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 4b9708a8c03f3..b56d69801171c 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell %* --class org.apache.spark.repl.Main +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql new file mode 100755 index 0000000000000..bba7f897b19bc --- /dev/null +++ b/bin/spark-sql @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL CLI + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/spark-sql [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/core/pom.xml b/core/pom.xml index 1054cec4d77bb..7c60cf10c3dc2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-core_2.10 - core + core jar Spark Project Core @@ -192,8 +192,8 @@ org.tachyonproject - tachyon - 0.4.1-thrift + tachyon-client + 0.5.0 org.apache.hadoop @@ -262,6 +262,11 @@ asm test + + junit + junit + test + com.novocode junit-interface @@ -275,7 +280,7 @@ net.sf.py4j py4j - 0.8.1 + 0.8.2.1 diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ff0ca11749d42..79c9c451d273d 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -56,18 +56,23 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // TODO: Make this non optional in a future release - Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) - Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled) + // Update task metrics if context is not null + // TODO: Make context non optional in a future release + Option(context).foreach { c => + c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled + c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + } combiners.iterator } } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") - def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = + def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = { + def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) + : Iterator[(K, C)] = + { if (!externalSorting) { val combiners = new AppendOnlyMap[K,C] var kc: Product2[K, C] = null @@ -85,9 +90,12 @@ case class Aggregator[K, V, C] ( val pair = iter.next() combiners.insert(pair._1, pair._2) } - // TODO: Make this non optional in a future release - Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) - Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled) + // Update task metrics if context is not null + // TODO: Make context non-optional in a future release + Option(context).foreach { c => + c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled + c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + } 
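The Aggregator hunk above changes the spill-metric bookkeeping from assignment to accumulation, so spills recorded by an earlier phase (e.g. combineValuesByKey) are not overwritten by combineCombinersByKey, and it guards against a null context on the deprecated call path. A minimal sketch of that pattern, using a hypothetical TaskMetricsLike stand-in for Spark's TaskMetrics:

```scala
// Hypothetical stand-in for TaskMetrics, with the two mutable spill counters used above.
class TaskMetricsLike {
  var memoryBytesSpilled: Long = 0L
  var diskBytesSpilled: Long = 0L
}

// Guard against a null metrics object and accumulate rather than assign, so metrics from
// multiple aggregation phases in the same task add up instead of clobbering each other.
def recordSpill(metrics: TaskMetricsLike, memSpilled: Long, diskSpilled: Long): Unit = {
  Option(metrics).foreach { m =>
    m.memoryBytesSpilled += memSpilled
    m.diskBytesSpilled += diskSpilled
  }
}
```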
combiners.iterator } } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index f010c03223ef4..ab2594cfc02eb 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -19,7 +19,6 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -28,29 +27,35 @@ import org.apache.spark.shuffle.ShuffleHandle * Base class for dependencies. */ @DeveloperApi -abstract class Dependency[T](val rdd: RDD[T]) extends Serializable +abstract class Dependency[T] extends Serializable { + def rdd: RDD[T] +} /** * :: DeveloperApi :: - * Base class for dependencies where each partition of the parent RDD is used by at most one - * partition of the child RDD. Narrow dependencies allow for pipelined execution. + * Base class for dependencies where each partition of the child RDD depends on a small number + * of partitions of the parent RDD. Narrow dependencies allow for pipelined execution. */ @DeveloperApi -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { +abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { /** * Get the parent partitions for a child partition. * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ def getParents(partitionId: Int): Seq[Int] + + override def rdd: RDD[T] = _rdd } /** * :: DeveloperApi :: - * Represents a dependency on the output of a shuffle stage. - * @param rdd the parent RDD + * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle, + * the RDD is transient since we don't need it on the executor side. + * + * @param _rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. 
If set to None, * the default serializer, as specified by `spark.serializer` config option, will @@ -58,21 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { */ @DeveloperApi class ShuffleDependency[K, V, C]( - @transient rdd: RDD[_ <: Product2[K, V]], + @transient _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, - val mapSideCombine: Boolean = false, - val sortOrder: Option[SortOrder] = None) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { + val mapSideCombine: Boolean = false) + extends Dependency[Product2[K, V]] { + + override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] - val shuffleId: Int = rdd.context.newShuffleId() + val shuffleId: Int = _rdd.context.newShuffleId() - val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( - shuffleId, rdd.partitions.size, this) + val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( + shuffleId, _rdd.partitions.size, this) - rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala new file mode 100644 index 0000000000000..24ccce21b62ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import akka.actor.Actor +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.scheduler.TaskScheduler + +/** + * A heartbeat from executors to the driver. This is a shared message used by several internal + * components to convey liveness or execution information for in-progress tasks. + */ +private[spark] case class Heartbeat( + executorId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId) + +private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) + +/** + * Lives in the driver to receive heartbeats from executors.. + */ +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { + override def receive = { + case Heartbeat(executorId, taskMetrics, blockManagerId) => + val response = HeartbeatResponse( + !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) + sender ! 
response + } +} diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 50d8e93e1f0d7..807ef3e9c9d60 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -45,10 +45,7 @@ trait Logging { initializeIfNecessary() var className = this.getClass.getName // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - log_ = LoggerFactory.getLogger(className) + log_ = LoggerFactory.getLogger(className.stripSuffix("$")) } log_ } @@ -110,23 +107,27 @@ trait Logging { } private def initializeLogging() { - // If Log4j is being used, but is not initialized, load a default properties file - val binder = StaticLoggerBinder.getSingleton - val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory") - val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4jInitialized && usingLog4j) { + // Don't use a logger in here, as this is itself occurring during initialization of a logger + // If Log4j 1.2 is being used, but is not initialized, load a default properties file + val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr + // This distinguishes the log4j 1.2 binding, currently + // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently + // org.apache.logging.slf4j.Log4jLoggerFactory + val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized && usingLog4j12) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { case Some(url) => PropertyConfigurator.configure(url) - log.info(s"Using Spark's default log4j profile: $defaultLogProps") + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") case None => System.err.println(s"Spark was unable to load $defaultLogProps") } } Logging.initialized = true - // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // Force a call into slf4j to initialize it. Avoids this happening from multiple threads // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html log } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 52c018baa5f7b..37053bb6f37ad 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -19,11 +19,15 @@ package org.apache.spark import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import scala.reflect.ClassTag +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.{ClassTag, classTag} +import scala.util.hashing.byteswap32 -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} +import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. 
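The RangePartitioner changes in the hunk that follows replace the old count()-and-collect sampling with a single-pass sketch followed by weighted bound selection (see RangePartitioner.determineBounds further below). The snippet here is a simplified, illustrative re-implementation of that bound-selection idea, not the Spark code itself:

```scala
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

// Simplified weighted bound selection: walk candidates in key order, accumulating their
// weights, and emit a bound each time the cumulative weight crosses the next step.
def pickBounds[K: Ordering: ClassTag](candidates: Seq[(K, Float)], partitions: Int): Array[K] = {
  val ordering = implicitly[Ordering[K]]
  val ordered = candidates.sortBy(_._1)
  val step = ordered.map(_._2.toDouble).sum / partitions
  var cumWeight = 0.0
  var target = step
  val bounds = ArrayBuffer.empty[K]
  var previous = Option.empty[K]
  for ((key, weight) <- ordered if bounds.size < partitions - 1) {
    cumWeight += weight
    // Skip duplicate keys so each bound is strictly greater than the previous one.
    if (cumWeight >= target && previous.forall(ordering.lt(_, key))) {
      bounds += key
      previous = Some(key)
      target += step
    }
  }
  bounds.toArray
}

// Example: 4 partitions over uniformly weighted keys 1..100 yields bounds near 25, 50, 75.
// pickBounds((1 to 100).map(k => (k, 1.0f)), 4)
```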
@@ -103,26 +107,49 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ascending: Boolean = true) extends Partitioner { + // We allow partitions = 0, which happens when sorting an empty RDD under the default settings. + require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.") + private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions private var rangeBounds: Array[K] = { - if (partitions == 1) { - Array() + if (partitions <= 1) { + Array.empty } else { - val rddSize = rdd.count() - val maxSampleSize = partitions * 20.0 - val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) - val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted - if (rddSample.length == 0) { - Array() + // This is the sample size we need to have roughly balanced output partitions, capped at 1M. + val sampleSize = math.min(20.0 * partitions, 1e6) + // Assume the input partitions are roughly balanced and over-sample a little bit. + val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt + val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) + if (numItems == 0L) { + Array.empty } else { - val bounds = new Array[K](partitions - 1) - for (i <- 0 until partitions - 1) { - val index = (rddSample.length - 1) * (i + 1) / partitions - bounds(i) = rddSample(index) + // If a partition contains much more than the average number of items, we re-sample from it + // to ensure that enough items are collected from that partition. + val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0) + val candidates = ArrayBuffer.empty[(K, Float)] + val imbalancedPartitions = mutable.Set.empty[Int] + sketched.foreach { case (idx, n, sample) => + if (fraction * n > sampleSizePerPartition) { + imbalancedPartitions += idx + } else { + // The weight is 1 over the sampling probability. + val weight = (n.toDouble / sample.size).toFloat + for (key <- sample) { + candidates += ((key, weight)) + } + } + } + if (imbalancedPartitions.nonEmpty) { + // Re-sample imbalanced partitions with the desired sampling probability. + val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains) + val seed = byteswap32(-rdd.id - 1) + val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect() + val weight = (1.0 / fraction).toFloat + candidates ++= reSampled.map(x => (x, weight)) } - bounds + RangePartitioner.determineBounds(candidates, partitions) } } } @@ -212,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } } + +private[spark] object RangePartitioner { + + /** + * Sketches the input RDD via reservoir sampling on each partition. 
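The sketch() method relies on SamplingUtils.reservoirSampleAndCount to take a fixed-size sample of each partition in one pass while also counting its items. A simplified, self-contained version of that reservoir-sampling step (illustrative only; the real utility lives in org.apache.spark.util.random):

```scala
import scala.reflect.ClassTag
import scala.util.Random

// Classic Algorithm R: keep the first k items, then replace a random slot with
// probability k / (n + 1) for each later item. One pass, bounded memory, and the
// total item count comes out for free.
def reservoirSampleAndCount[T: ClassTag](
    iter: Iterator[T], k: Int, seed: Long): (Array[T], Int) = {
  val rng = new Random(seed)
  val reservoir = new Array[T](k)
  var n = 0
  iter.foreach { item =>
    if (n < k) {
      reservoir(n) = item
    } else {
      val j = rng.nextInt(n + 1)
      if (j < k) reservoir(j) = item
    }
    n += 1
  }
  if (n < k) (reservoir.take(n), n) else (reservoir, n)
}
```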
+ * + * @param rdd the input RDD to sketch + * @param sampleSizePerPartition max sample size per partition + * @return (total number of items, an array of (partitionId, number of items, sample)) + */ + def sketch[K:ClassTag]( + rdd: RDD[K], + sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + val shift = rdd.id + // val classTagK = classTag[K] // to avoid serializing the entire partitioner object + val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => + val seed = byteswap32(idx ^ (shift << 16)) + val (sample, n) = SamplingUtils.reservoirSampleAndCount( + iter, sampleSizePerPartition, seed) + Iterator((idx, n, sample)) + }.collect() + val numItems = sketched.map(_._2.toLong).sum + (numItems, sketched) + } + + /** + * Determines the bounds for range partitioning from candidates with weights indicating how many + * items each represents. Usually this is 1 over the probability used to sample this candidate. + * + * @param candidates unordered candidates with weights + * @param partitions number of partitions + * @return selected bounds + */ + def determineBounds[K:Ordering:ClassTag]( + candidates: ArrayBuffer[(K, Float)], + partitions: Int): Array[K] = { + val ordering = implicitly[Ordering[K]] + val ordered = candidates.sortBy(_._1) + val numCandidates = ordered.size + val sumWeights = ordered.map(_._2.toDouble).sum + val step = sumWeights / partitions + var cumWeight = 0.0 + var target = step + val bounds = ArrayBuffer.empty[K] + var i = 0 + var j = 0 + var previousBound = Option.empty[K] + while ((i < numCandidates) && (j < partitions - 1)) { + val (key, weight) = ordered(i) + cumWeight += weight + if (cumWeight > target) { + // Skip duplicate values. + if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { + bounds += key + target += step + j += 1 + previousBound = Some(key) + } + } + i += 1 + } + bounds.toArray + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8ce4b91cae8ae..38700847c80f4 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -40,6 +40,8 @@ import scala.collection.mutable.HashMap */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { + import SparkConf._ + /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) @@ -198,7 +200,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * * E.g. spark.akka.option.x.y.x = "value" */ - getAll.filter {case (k, v) => k.startsWith("akka.")} + getAll.filter { case (k, _) => isAkkaConf(k) } /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) @@ -292,3 +294,21 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } } + +private[spark] object SparkConf { + /** + * Return whether the given config is an akka config (e.g. akka.actor.provider). + * Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). + */ + def isAkkaConf(name: String): Boolean = name.startsWith("akka.") + + /** + * Return whether the given config should be passed to an executor on start-up. + * + * Certain akka and authentication configs are required of the executor when it connects to + * the scheduler, while the rest of the spark configs can be inherited from the driver later. 
+ */ + def isExecutorStartupConf(name: String): Boolean = { + isAkkaConf(name) || name.startsWith("spark.akka") || name.startsWith("spark.auth") + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e6addeaf04a8..368835a867493 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary +import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast @@ -289,7 +290,7 @@ class SparkContext(config: SparkConf) extends Logging { value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => executorEnvs("SPARK_PREPEND_CLASSES") = v } // The Mesos scheduler backend relies on this environment variable to set executor memory. @@ -307,6 +308,8 @@ class SparkContext(config: SparkConf) extends Logging { // Create and start the scheduler private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master) + private val heartbeatReceiver = env.actorSystem.actorOf( + Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -455,7 +458,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -990,15 +993,15 @@ class SparkContext(config: SparkConf) extends Logging { val dagSchedulerCopy = dagScheduler dagScheduler = null if (dagSchedulerCopy != null) { + env.metricsSystem.report() metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? env.stop() SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() listenerBus.stop() eventLogger.foreach(_.stop()) logInfo("Successfully stopped SparkContext") @@ -1205,10 +1208,10 @@ class SparkContext(config: SparkConf) extends Logging { /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) - * If checkSerializable is set, clean will also proactively - * check to see if f is serializable and throw a SparkException + * If checkSerializable is set, clean will also proactively + * check to see if f is serializable and throw a SparkException * if not. 
- * + * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability * @throws SparkException if checkSerializable is set but f is not @@ -1453,9 +1456,9 @@ object SparkContext extends Logging { /** Creates a task scheduler based on a given master URL. Extracted for testing. */ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -1485,8 +1488,12 @@ object SparkContext extends Logging { scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + def localCpuCount = Runtime.getRuntime.availableProcessors() + // local[*, M] means the number of cores on the computer with M failures + // local[N, M] means exactly N threads with M failures + val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) + val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6ee731b22c03c..92c809d854167 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -193,13 +193,7 @@ object SparkEnv extends Logging { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" - val timeout = AkkaUtils.lookupTimeout(conf) - logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + AkkaUtils.makeDriverRef(name, conf, actorSystem) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala new file mode 100644 index 0000000000000..0ae0b4ec042e2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.hadoop.mapred.InputSplit + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaSparkContext._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} +import org.apache.spark.rdd.HadoopRDD + +@DeveloperApi +class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) + (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + extends JavaPairRDD[K, V](rdd) { + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[R]( + f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], + preservesPartitioning: Boolean = false): JavaRDD[R] = { + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + preservesPartitioning)(fakeClassTag))(fakeClassTag) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala new file mode 100644 index 0000000000000..ec4f3964d75e0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.hadoop.mapreduce.InputSplit + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaSparkContext._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} +import org.apache.spark.rdd.NewHadoopRDD + +@DeveloperApi +class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) + (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + extends JavaPairRDD[K, V](rdd) { + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. 
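Both wrappers expose mapPartitionsWithInputSplit to Java callers; on the Scala side the same hook on HadoopRDD can be used, for example, to tag every record with the file it came from. A hedged usage sketch (the helper name and call shape are illustrative, based on the signatures the wrappers above delegate to):

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.{HadoopRDD, RDD}

// Illustrative helper: read text files and pair each line with its source file name.
def linesWithFileName(sc: SparkContext, path: String): RDD[(String, String)] = {
  // hadoopFile actually returns a HadoopRDD, which is what exposes the InputSplit hook.
  val hadoopRdd = sc.hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
    .asInstanceOf[HadoopRDD[LongWritable, Text]]
  hadoopRdd.mapPartitionsWithInputSplit((split, iter) => {
    val file = split.asInstanceOf[FileSplit].getPath.toString
    // Copy the Text value; Hadoop record readers may reuse the same Writable instance.
    iter.map { case (_, line) => (file, line.toString) }
  }, false)
}
```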
*/ + @DeveloperApi + def mapPartitionsWithInputSplit[R]( + f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], + preservesPartitioning: Boolean = false): JavaRDD[R] = { + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + preservesPartitioning)(fakeClassTag))(fakeClassTag) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 4f3081433a542..76d4193e96aea 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList} +import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} import scala.collection.JavaConversions._ @@ -122,13 +122,80 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over + * the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. + */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + exact: Boolean, + seed: Long): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed)) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over + * the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. + * + * Use Utils.random.nextLong as the default seed for the random number generator + */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + exact: Boolean): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * Produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via + * simple random sampling. 
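These overloads all delegate to rdd.sampleByKey on the Scala side. A small hedged usage sketch against the four-argument call the wrappers above make (withReplacement, fractions, exact, seed):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._ // pair-RDD implicits in this code base

def stratifiedSample(sc: SparkContext): Unit = {
  val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4), ("b", 5)))
  // Keep roughly 50% of the "a" records and 20% of the "b" records.
  val fractions = Map("a" -> 0.5, "b" -> 0.2)
  // Arguments: withReplacement, fractions, exact, seed.
  val sampled = pairs.sampleByKey(false, fractions, false, 42L)
  println(sampled.countByKey())
}
```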
+ */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + seed: Long): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, false, seed) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * Produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via + * simple random sampling. + * + * Use Utils.random.nextLong as the default seed for the random number generator + */ + def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, false, Utils.random.nextLong) + /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). @@ -716,6 +783,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) sortByKey(comp, ascending) } + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] + sortByKey(comp, ascending, numPartitions) + } + /** * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling * `collect` or `save` on the resulting RDD will return or output an ordered list of records diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8a5f8088a05ca..d9d1c5955ca99 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -34,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns @@ -294,7 +294,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions)) + val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** @@ -314,7 +315,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) + val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** Get an RDD for a Hadoop file with an arbitrary 
InputFormat. @@ -333,7 +335,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)) + val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat @@ -351,8 +354,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopFile(path, - inputFormatClass, keyClass, valueClass)) + val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** @@ -372,7 +375,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork conf: Configuration): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(kClass) implicit val ctagV: ClassTag[V] = ClassTag(vClass) - new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)) + val rdd = sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf) + new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]]) } /** @@ -391,7 +395,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork vClass: Class[V]): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(kClass) implicit val ctagV: ClassTag[V] = ClassTag(vClass) - new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) + val rdd = sc.newAPIHadoopRDD(conf, fClass, kClass, vClass) + new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]]) } /** Build the union of two or more RDDs. */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index adaa1ef6cf9ff..f3b05e1243045 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.python +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.Logging +import org.apache.spark.{Logging, SerializableWritable, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -31,13 +32,14 @@ import org.apache.spark.annotation.Experimental * transformation code by overriding the convert method. 
*/ @Experimental -trait Converter[T, U] extends Serializable { +trait Converter[T, + U] extends Serializable { def convert(obj: T): U } private[python] object Converter extends Logging { - def getInstance(converterClass: Option[String]): Converter[Any, Any] = { + def getInstance(converterClass: Option[String], + defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] @@ -49,7 +51,7 @@ private[python] object Converter extends Logging { logError(s"Failed to load converter: $cc") throw err } - }.getOrElse { new DefaultConverter } + }.getOrElse { defaultConverter } } } @@ -57,7 +59,9 @@ private[python] object Converter extends Logging { * A converter that handles conversion of common [[org.apache.hadoop.io.Writable]] objects. * Other objects are passed through without conversion. */ -private[python] class DefaultConverter extends Converter[Any, Any] { +private[python] class WritableToJavaConverter( + conf: Broadcast[SerializableWritable[Configuration]], + batchSize: Int) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -72,17 +76,30 @@ private[python] class DefaultConverter extends Converter[Any, Any] { case fw: FloatWritable => fw.get() case t: Text => t.toString case bw: BooleanWritable => bw.get() - case byw: BytesWritable => byw.getBytes + case byw: BytesWritable => + val bytes = new Array[Byte](byw.getLength) + System.arraycopy(byw.getBytes(), 0, bytes, 0, byw.getLength) + bytes case n: NullWritable => null - case aw: ArrayWritable => aw.get().map(convertWritable(_)) - case mw: MapWritable => mapAsJavaMap(mw.map { case (k, v) => - (convertWritable(k), convertWritable(v)) - }.toMap) + case aw: ArrayWritable => + // Due to erasure, all arrays appear as Object[] and they get pickled to Python tuples. + // Since we can't determine element types for empty arrays, we will not attempt to + // convert to primitive arrays (which get pickled to Python arrays). Users may want + // write custom converters for arrays if they know the element types a priori. + aw.get().map(convertWritable(_)) + case mw: MapWritable => + val map = new java.util.HashMap[Any, Any]() + mw.foreach { case (k, v) => + map.put(convertWritable(k), convertWritable(v)) + } + map + case w: Writable => + if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w case other => other } } - def convert(obj: Any): Any = { + override def convert(obj: Any): Any = { obj match { case writable: Writable => convertWritable(writable) @@ -92,6 +109,47 @@ private[python] class DefaultConverter extends Converter[Any, Any] { } } +/** + * A converter that converts common types to [[org.apache.hadoop.io.Writable]]. Note that array + * types are not supported since the user needs to subclass [[org.apache.hadoop.io.ArrayWritable]] + * to set the type properly. See [[org.apache.spark.api.python.DoubleArrayWritable]] and + * [[org.apache.spark.api.python.DoubleArrayToWritableConverter]] for an example. They are used in + * PySpark RDD `saveAsNewAPIHadoopFile` doctest. + */ +private[python] class JavaToWritableConverter extends Converter[Any, Writable] { + + /** + * Converts common data types to [[org.apache.hadoop.io.Writable]]. Note that array types are not + * supported out-of-the-box. 
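As the comments above note, array values are not converted out-of-the-box because the element type of an ArrayWritable must be fixed by a subclass. A hedged sketch of the kind of custom converter the docs point to (class names here are illustrative, mirroring the DoubleArrayWritable / DoubleArrayToWritableConverter example mentioned above):

```scala
import org.apache.hadoop.io.{ArrayWritable, DoubleWritable, Writable}
import org.apache.spark.api.python.Converter

// ArrayWritable subclass that pins the element type to DoubleWritable.
class DoubleArrayWritableExample extends ArrayWritable(classOf[DoubleWritable])

// Converter a user could register (e.g. as a valueConverterClass) to write Array[Double] values.
class DoubleArrayToWritableExample extends Converter[Any, Writable] {
  override def convert(obj: Any): Writable = obj match {
    case arr: Array[Double] =>
      val writable = new DoubleArrayWritableExample
      writable.set(arr.map(d => new DoubleWritable(d): Writable))
      writable
    case other =>
      throw new IllegalArgumentException(s"Expected Array[Double] but got ${other.getClass}")
  }
}
```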
+ */ + private def convertToWritable(obj: Any): Writable = { + import collection.JavaConversions._ + obj match { + case i: java.lang.Integer => new IntWritable(i) + case d: java.lang.Double => new DoubleWritable(d) + case l: java.lang.Long => new LongWritable(l) + case f: java.lang.Float => new FloatWritable(f) + case s: java.lang.String => new Text(s) + case b: java.lang.Boolean => new BooleanWritable(b) + case aob: Array[Byte] => new BytesWritable(aob) + case null => NullWritable.get() + case map: java.util.Map[_, _] => + val mapWritable = new MapWritable() + map.foreach { case (k, v) => + mapWritable.put(convertToWritable(k), convertToWritable(v)) + } + mapWritable + case other => throw new SparkException( + s"Data of type ${other.getClass.getName} cannot be used") + } + } + + override def convert(obj: Any): Writable = obj match { + case writable: Writable => writable + case other => convertToWritable(other) + } +} + /** Utilities for working with Python objects <-> Hadoop-related objects */ private[python] object PythonHadoopUtil { @@ -118,7 +176,7 @@ private[python] object PythonHadoopUtil { /** * Converts an RDD of key-value pairs, where key and/or value could be instances of - * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)] + * [[org.apache.hadoop.io.Writable]], into an RDD of base types, or vice versa. */ def convertRDD[K, V](rdd: RDD[(K, V)], keyConverter: Converter[Any, Any], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d87783efd2d01..94d666aa92025 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,15 +23,18 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.language.existentials import scala.reflect.ClassTag import scala.util.Try import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.{InputFormat, JobConf} -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -365,19 +368,17 @@ private[spark] object PythonRDD extends Logging { valueClassMaybeNull: String, keyConverterClass: String, valueConverterClass: String, - minSplits: Int) = { + minSplits: Int, + batchSize: Int) = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val keyConverter = 
Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -394,17 +395,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { - val conf = PythonHadoopUtil.mapToConf(confAsMap) - val baseConf = sc.hadoopConfiguration() - val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { + val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -421,15 +421,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]]( @@ -439,18 +440,14 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - val fc = fcm.runtimeClass.asInstanceOf[Class[F]] - val rdd = if (path.isDefined) { + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val fc = 
Class.forName(inputFormatClass).asInstanceOf[Class[F]] + if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { sc.sc.newAPIHadoopRDD[K, V, F](conf, fc, kc, vc) } - rdd } /** @@ -467,17 +464,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { - val conf = PythonHadoopUtil.mapToConf(confAsMap) - val baseConf = sc.hadoopConfiguration() - val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { + val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -494,15 +490,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]]( @@ -512,18 +509,14 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - val fc = fcm.runtimeClass.asInstanceOf[Class[F]] - val rdd = if (path.isDefined) { + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { sc.sc.hadoopRDD(new JobConf(conf), fc, kc, vc) } - rdd } def writeUTF(str: String, dataOut: DataOutputStream) { @@ -544,23 +537,170 @@ private[spark] object PythonRDD extends Logging { } /** - * Convert an RDD of serialized Python dictionaries to Scala Maps + * Convert an RDD of serialized 
Python dictionaries to Scala Maps (no recursive conversions). + * It is only used by pyspark.sql. * TODO: Support more Python types. */ def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler - // TODO: Figure out why flatMap is necessay for pyspark iter.flatMap { row => unpickle.loads(row) match { + // in case of objects are pickled in batch mode case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // Incase the partition doesn't have a collection + // not in batch mode case obj: JMap[String @unchecked, _] => Seq(obj.toMap) } } } } + private def getMergedConf(confAsMap: java.util.HashMap[String, String], + baseConf: Configuration): Configuration = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + PythonHadoopUtil.mergeConfs(baseConf, conf) + } + + private def inferKeyValueTypes[K, V](rdd: RDD[(K, V)], keyConverterClass: String = null, + valueConverterClass: String = null): (Class[_], Class[_]) = { + // Peek at an element to figure out key/value types. Since Writables are not serializable, + // we cannot call first() on the converted RDD. Instead, we call first() on the original RDD + // and then convert locally. + val (key, value) = rdd.first() + val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + (kc.convert(key).getClass, vc.convert(value).getClass) + } + + private def getKeyValueTypes(keyClass: String, valueClass: String): + Option[(Class[_], Class[_])] = { + for { + k <- Option(keyClass) + v <- Option(valueClass) + } yield (Class.forName(k), Class.forName(v)) + } + + private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, + defaultConverter: Converter[Any, Any]): (Converter[Any, Any], Converter[Any, Any]) = { + val keyConverter = Converter.getInstance(Option(keyConverterClass), defaultConverter) + val valueConverter = Converter.getInstance(Option(valueConverterClass), defaultConverter) + (keyConverter, valueConverter) + } + + /** + * Convert an RDD of key-value pairs from internal types to serializable types suitable for + * output, or vice versa. + */ + private def convertRDD[K, V](rdd: RDD[(K, V)], + keyConverterClass: String, + valueConverterClass: String, + defaultConverter: Converter[Any, Any]): RDD[(Any, Any)] = { + val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, + defaultConverter) + PythonHadoopUtil.convertRDD(rdd, kc, vc) + } + + /** + * Output a Python RDD of key-value pairs as a Hadoop SequenceFile using the Writable types + * we convert from the RDD's key and value types. Note that keys and values can't be + * [[org.apache.hadoop.io.Writable]] types already, since Writables are not Java + * `Serializable` and we can't peek at them. The `path` can be on any Hadoop file system. + */ + def saveAsSequenceFile[K, V, C <: CompressionCodec]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + compressionCodecClass: String) = { + saveAsHadoopFile( + pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", + null, null, null, null, new java.util.HashMap(), compressionCodecClass) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using old Hadoop + * `OutputFormat` in mapred package. Keys and values are converted to suitable output + * types using either user specified converters or, if not specified, + * [[org.apache.spark.api.python.JavaToWritableConverter]]. 
Post-conversion types + * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in + * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of + * this RDD. + */ + def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], C <: CompressionCodec]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + outputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String], + compressionCodecClass: String) = { + val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) + val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( + inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) + val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) + val codec = Option(compressionCodecClass).map(Class.forName(_).asInstanceOf[Class[C]]) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using new Hadoop + * `OutputFormat` in mapreduce package. Keys and values are converted to suitable output + * types using either user specified converters or, if not specified, + * [[org.apache.spark.api.python.JavaToWritableConverter]]. Post-conversion types + * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in + * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of + * this RDD. + */ + def saveAsNewAPIHadoopFile[K, V, F <: NewOutputFormat[_, _]]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + outputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String]) = { + val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) + val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( + inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) + val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using a Hadoop conf + * converted from the passed-in `confAsMap`. The conf should set relevant output params ( + * e.g., output path, output format, etc), in the same way as it would be configured for + * a Hadoop MapReduce job. Both old and new Hadoop OutputFormat APIs are supported + * (mapred vs. mapreduce). Keys/values are converted for output using either user specified + * converters or, by default, [[org.apache.spark.api.python.JavaToWritableConverter]]. 
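As a sketch of what such a user-specified converter looks like (the class below is hypothetical and not part of this patch; only the Converter[Any, Any] contract is taken from the code above), it would be referenced by fully-qualified class name through the keyConverterClass/valueConverterClass parameters:

    import org.apache.hadoop.io.Text
    import org.apache.spark.api.python.Converter

    // Hypothetical value converter: upper-cases values on their way out to the OutputFormat.
    class UpperCaseTextOutputConverter extends Converter[Any, Any] {
      override def convert(obj: Any): Any = new Text(obj.toString.toUpperCase)
    }
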
+ */ + def saveAsHadoopDataset[K, V]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + confAsMap: java.util.HashMap[String, String], + keyConverterClass: String, + valueConverterClass: String, + useNewAPI: Boolean) = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), + keyConverterClass, valueConverterClass, new JavaToWritableConverter) + if (useNewAPI) { + converted.saveAsNewAPIHadoopDataset(conf) + } else { + converted.saveAsHadoopDataset(new JobConf(conf)) + } + } + /** * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by * PySpark. @@ -591,19 +731,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) + /** + * We try to reuse a single Socket to transfer accumulator updates, as they are all added + * by the DAGScheduler's single-threaded actor anyway. + */ + @transient var socket: Socket = _ + + def openSocket(): Socket = synchronized { + if (socket == null || socket.isClosed) { + socket = new Socket(serverHost, serverPort) + } + socket + } + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = { + : JList[Array[Byte]] = synchronized { if (serverHost == null) { // This happens on the worker node, where we just want to remember all the updates val1.addAll(val2) val1 } else { // This happens on the master, where we pass the updates to Python through a socket - val socket = new Socket(serverHost, serverPort) - // SPARK-2282: Immediately reuse closed sockets because we create one per task. - socket.setReuseAddress(true) + val socket = openSocket() val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) @@ -617,7 +768,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: if (byteRead == -1) { throw new SparkException("EOF reached before Python server acknowledged") } - socket.close() null } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 6d3e257c4d5df..52c70712eea3d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -29,7 +29,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.1-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.2.1-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 9a012e7254901..efc9009c088a8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -17,13 +17,14 @@ package org.apache.spark.api.python -import scala.util.Try -import org.apache.spark.rdd.RDD -import org.apache.spark.Logging -import scala.util.Success +import 
scala.collection.JavaConversions._ import scala.util.Failure -import net.razorvine.pickle.Pickler +import scala.util.Try +import net.razorvine.pickle.{Unpickler, Pickler} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ private[python] object SerDeUtil extends Logging { @@ -65,20 +66,52 @@ private[python] object SerDeUtil extends Logging { * by PySpark. By default, if serialization fails, toString is called and the string * representation is serialized */ - def rddToPython(rdd: RDD[(Any, Any)]): RDD[Array[Byte]] = { + def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) rdd.mapPartitions { iter => val pickle = new Pickler - iter.map { case (k, v) => - if (keyFailed && valueFailed) { - pickle.dumps(Array(k.toString, v.toString)) - } else if (keyFailed) { - pickle.dumps(Array(k.toString, v)) - } else if (!keyFailed && valueFailed) { - pickle.dumps(Array(k, v.toString)) + val cleaned = iter.map { case (k, v) => + val key = if (keyFailed) k.toString else k + val value = if (valueFailed) v.toString else v + Array[Any](key, value) + } + if (batchSize > 1) { + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + } else { + cleaned.map(pickle.dumps(_)) + } + } + } + + /** + * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. + */ + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def isPair(obj: Any): Boolean = { + Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + obj.asInstanceOf[Array[_]].length == 2 + } + pyRDD.mapPartitions { iter => + val unpickle = new Unpickler + val unpickled = + if (batchSerialized) { + iter.flatMap { batch => + unpickle.loads(batch) match { + case objs: java.util.List[_] => collectionAsScalaIterable(objs) + case other => throw new SparkException( + s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") + } + } } else { - pickle.dumps(Array(k, v)) + iter.map(unpickle.loads(_)) } + unpickled.map { + case obj if isPair(obj) => + // we only accept (K, V) + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index f0e3fb9aff5a0..d11db978b842e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -17,15 +17,16 @@ package org.apache.spark.api.python -import org.apache.spark.SparkContext -import org.apache.hadoop.io._ -import scala.Array import java.io.{DataOutput, DataInput} +import java.nio.charset.Charset + +import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.{SparkContext, SparkException} /** - * A class to test MsgPack serialization on the Scala side, that will be deserialized + * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python * @param str * @param int @@ -54,7 +55,13 @@ 
case class TestWritable(var str: String, var int: Int, var double: Double) exten } } -class TestConverter extends Converter[Any, Any] { +private[python] class TestInputKeyConverter extends Converter[Any, Any] { + override def convert(obj: Any) = { + obj.asInstanceOf[IntWritable].get().toChar + } +} + +private[python] class TestInputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ override def convert(obj: Any) = { val m = obj.asInstanceOf[MapWritable] @@ -62,6 +69,38 @@ class TestConverter extends Converter[Any, Any] { } } +private[python] class TestOutputKeyConverter extends Converter[Any, Any] { + override def convert(obj: Any) = { + new Text(obj.asInstanceOf[Int].toString) + } +} + +private[python] class TestOutputValueConverter extends Converter[Any, Any] { + import collection.JavaConversions._ + override def convert(obj: Any) = { + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + } +} + +private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) + +private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { + override def convert(obj: Any) = obj match { + case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => + val daw = new DoubleArrayWritable + daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) + daw + case other => throw new SparkException(s"Data of type $other is not supported") + } +} + +private[python] class WritableToDoubleArrayConverter extends Converter[Any, Array[Double]] { + override def convert(obj: Any): Array[Double] = obj match { + case daw : DoubleArrayWritable => daw.get().map(_.asInstanceOf[DoubleWritable].get()) + case other => throw new SparkException(s"Data of type $other is not supported") + } +} + /** * This object contains method to generate SequenceFile test data and write it to a * given directory (probably a temp directory) @@ -97,7 +136,8 @@ object WriteInputFormatTestDataGenerator { sc.parallelize(intKeys).saveAsSequenceFile(intPath) sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) - sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes) }).saveAsSequenceFile(bytesPath) + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) } + ).saveAsSequenceFile(bytesPath) val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) sc.parallelize(bools).saveAsSequenceFile(boolPath) sc.parallelize(intKeys).map{ case (k, v) => @@ -106,19 +146,20 @@ object WriteInputFormatTestDataGenerator { // Create test data for ArrayWritable val data = Seq( - (1, Array(1.0, 2.0, 3.0)), + (1, Array()), (2, Array(3.0, 4.0, 5.0)), (3, Array(4.0, 5.0, 6.0)) ) sc.parallelize(data, numSlices = 2) .map{ case (k, v) => - (new IntWritable(k), new ArrayWritable(classOf[DoubleWritable], v.map(new DoubleWritable(_)))) - }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, ArrayWritable]](arrPath) + val va = new DoubleArrayWritable + va.set(v.map(new DoubleWritable(_))) + (new IntWritable(k), va) + }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, DoubleArrayWritable]](arrPath) // Create test data for MapWritable, with keys DoubleWritable and values Text val mapData = Seq( - (1, Map(2.0 -> "aa")), - (2, Map(3.0 -> "bb")), + (1, Map()), (2, Map(1.0 -> "cc")), (3, Map(2.0 -> "dd")), (2, Map(1.0 -> "aa")), @@ -126,9 
+167,9 @@ object WriteInputFormatTestDataGenerator { ) sc.parallelize(mapData, numSlices = 2).map{ case (i, m) => val mw = new MapWritable() - val k = m.keys.head - val v = m.values.head - mw.put(new DoubleWritable(k), new Text(v)) + m.foreach { case (k, v) => + mw.put(new DoubleWritable(k), new Text(v)) + } (new IntWritable(i), mw) }.saveAsSequenceFile(mapPath) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c371dc3a51c73..17c507af2652d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy -import scala.collection.JavaConversions._ -import scala.collection.mutable.Map import scala.concurrent._ import akka.actor._ @@ -50,9 +48,6 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends // TODO: We could add an env variable here and intercept it in `sc.addJar` that would // truncate filesystem paths similar to what YARN does. For now, we just require // people call `addJar` assuming the jar is in the same directory. - val env = Map[String, String]() - System.getenv().foreach{case (k, v) => env(k) = v} - val mainClass = "org.apache.spark.deploy.worker.DriverWrapper" val classPathConf = "spark.driver.extraClassPath" @@ -65,10 +60,13 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends cp.split(java.io.File.pathSeparator) } - val javaOptionsConf = "spark.driver.extraJavaOptions" - val javaOpts = sys.props.get(javaOptionsConf) + val extraJavaOptsConf = "spark.driver.extraJavaOptions" + val extraJavaOpts = sys.props.get(extraJavaOptsConf) + .map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ - driverArgs.driverOptions, env, classPathEntries, libraryPathEntries, javaOpts) + driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts) val driverDescription = new DriverDescription( driverArgs.jarUrl, @@ -109,6 +107,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends // Exception, if present statusResponse.exception.map { e => println(s"Exception from cluster was: $e") + e.printStackTrace() System.exit(-1) } System.exit(0) @@ -141,8 +140,10 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends */ object Client { def main(args: Array[String]) { - println("WARNING: This client is deprecated and will be removed in a future version of Spark.") - println("Use ./bin/spark-submit with \"--master spark://host:port\"") + if (!sys.props.contains("SPARK_SUBMIT")) { + println("WARNING: This client is deprecated and will be removed in a future version of Spark") + println("Use ./bin/spark-submit with \"--master spark://host:port\"") + } val conf = new SparkConf() val driverArgs = new ClientArguments(args) diff --git a/core/src/main/scala/org/apache/spark/deploy/Command.scala b/core/src/main/scala/org/apache/spark/deploy/Command.scala index 32f3ba385084f..a2b263544c6a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Command.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Command.scala @@ -25,5 +25,5 @@ private[spark] case class Command( environment: Map[String, String], classPathEntries: Seq[String], libraryPathEntries: Seq[String], - extraJavaOptions: Option[String] = 
None) { + javaOpts: Seq[String]) { } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3b5642b6caa36..318509a67a36f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -46,6 +46,10 @@ object SparkSubmit { private val CLUSTER = 2 private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + // A special jar name that indicates the class being run is inside of Spark itself, and therefore + // no user jar is needed. + private val SPARK_INTERNAL = "spark-internal" + // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" @@ -132,8 +136,6 @@ object SparkSubmit { (clusterManager, deployMode) match { case (MESOS, CLUSTER) => printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") - case (STANDALONE, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Standalone clusters.") case (_, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python applications.") case (_, CLUSTER) if isShell(args.primaryResource) => @@ -166,9 +168,9 @@ object SparkSubmit { val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.master"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.app.name"), - OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), + OptionAssigner(args.jars, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), // Standalone cluster only OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), @@ -182,7 +184,7 @@ object SparkSubmit { OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), // Yarn cluster only - OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name", sysProp = "spark.app.name"), + OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), @@ -199,9 +201,9 @@ object SparkSubmit { sysProp = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraLibraryPath"), - OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, CLIENT, + OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), - OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, CLIENT, + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.files") @@ -257,21 +259,26 @@ object SparkSubmit { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (clusterManager == YARN && deployMode == CLUSTER) { childMainClass = "org.apache.spark.deploy.yarn.Client" - childArgs += ("--jar", args.primaryResource) + if (args.primaryResource != 
SPARK_INTERNAL) { + childArgs += ("--jar", args.primaryResource) + } childArgs += ("--class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) } } } + // Properties given with --conf are superceded by other options, but take precedence over + // properties in the defaults file. + for ((k, v) <- args.sparkProperties) { + sysProps.getOrElseUpdate(k, v) + } + // Read from default spark properties, if any for ((k, v) <- args.getDefaultSparkProperties) { sysProps.getOrElseUpdate(k, v) } - // Spark properties included on command line take precedence - sysProps ++= args.sparkProperties - (childArgs, childClasspath, sysProps, childMainClass) } @@ -332,7 +339,7 @@ object SparkSubmit { * Return whether the given primary resource represents a user jar. */ private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) + !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) } /** @@ -349,6 +356,10 @@ object SparkSubmit { primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL } + private[spark] def isInternal(primaryResource: String): Boolean = { + primaryResource == SPARK_INTERNAL + } + /** * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3ab67a43a3b55..dd044e6298760 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -58,7 +58,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { val sparkProperties: HashMap[String, String] = new HashMap[String, String]() parseOpts(args.toList) - loadDefaults() + mergeSparkProperties() checkRequiredArguments() /** Return default present in the currently defined defaults file. */ @@ -79,9 +79,11 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { defaultProperties } - /** Fill in any undefined values based on the current properties file or built-in defaults. */ - private def loadDefaults(): Unit = { - + /** + * Fill in any undefined values based on the default properties file or options passed in through + * the '--conf' flag. + */ + private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user if (propertiesFile == null) { sys.env.get("SPARK_HOME").foreach { sparkHome => @@ -94,18 +96,20 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { } } - val defaultProperties = getDefaultSparkProperties + val properties = getDefaultSparkProperties + properties.putAll(sparkProperties) + // Use properties file as fallback for values which have a direct analog to // arguments in this script. 
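A minimal sketch of the resulting precedence (all values below are made up): because getOrElseUpdate never overwrites an existing key, explicitly parsed options win over --conf entries, which in turn win over the defaults file:

    import scala.collection.mutable

    // Writes happen in the same order as in SparkSubmit: explicit options first,
    // then --conf properties, then the defaults file.
    val sysProps   = mutable.HashMap("spark.master" -> "yarn")                      // from --master
    val confFlags  = Seq("spark.master" -> "local[4]", "spark.ui.port" -> "4050")   // from --conf
    val defaults   = Seq("spark.master" -> "spark://host:7077", "spark.eventLog.enabled" -> "true")
    (confFlags ++ defaults).foreach { case (k, v) => sysProps.getOrElseUpdate(k, v) }
    // spark.master stays "yarn"; spark.ui.port comes from --conf; eventLog.enabled from the defaults file.
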
- master = Option(master).getOrElse(defaultProperties.get("spark.master").orNull) + master = Option(master).getOrElse(properties.get("spark.master").orNull) executorMemory = Option(executorMemory) - .getOrElse(defaultProperties.get("spark.executor.memory").orNull) + .getOrElse(properties.get("spark.executor.memory").orNull) executorCores = Option(executorCores) - .getOrElse(defaultProperties.get("spark.executor.cores").orNull) + .getOrElse(properties.get("spark.executor.cores").orNull) totalExecutorCores = Option(totalExecutorCores) - .getOrElse(defaultProperties.get("spark.cores.max").orNull) - name = Option(name).getOrElse(defaultProperties.get("spark.app.name").orNull) - jars = Option(jars).getOrElse(defaultProperties.get("spark.jars").orNull) + .getOrElse(properties.get("spark.cores.max").orNull) + name = Option(name).getOrElse(properties.get("spark.app.name").orNull) + jars = Option(jars).getOrElse(properties.get("spark.jars").orNull) // This supports env vars in older versions of Spark master = Option(master).getOrElse(System.getenv("MASTER")) @@ -204,8 +208,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - // Delineates parsing of Spark options from parsing of user options. var inSparkOpts = true + + // Delineates parsing of Spark options from parsing of user options. parse(opts) def parse(opts: Seq[String]): Unit = opts match { @@ -318,7 +323,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { SparkSubmit.printErrorAndExit(errMessage) case v => primaryResource = - if (!SparkSubmit.isShell(v)) { + if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { Utils.resolveURI(v).toString } else { v diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index e15a87bd38fda..b8ffa9afb69cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -46,11 +46,11 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription( - "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), - Seq()), Some("dummy-spark-home"), "ignored") + "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), + Seq(), Seq(), Seq()), Some("dummy-spark-home"), "ignored") val listener = new TestListener val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) client.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 01e7065c17b69..6d2d4cef1ee46 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -36,11 +36,11 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis conf.getInt("spark.history.updateInterval", 10)) * 1000 private val logDir = conf.get("spark.history.fs.logDirectory", null) - if (logDir == null) { 
- throw new IllegalArgumentException("Logging directory must be specified.") - } + private val resolvedLogDir = Option(logDir) + .map { d => Utils.resolveURI(d) } + .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") } - private val fs = Utils.getHadoopFileSystem(logDir) + private val fs = Utils.getHadoopFileSystem(resolvedLogDir) // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L @@ -76,14 +76,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def initialize() { // Validate the log directory. - val path = new Path(logDir) + val path = new Path(resolvedLogDir) if (!fs.exists(path)) { throw new IllegalArgumentException( - "Logging directory specified does not exist: %s".format(logDir)) + "Logging directory specified does not exist: %s".format(resolvedLogDir)) } if (!fs.getFileStatus(path).isDir) { throw new IllegalArgumentException( - "Logging directory specified is not a directory: %s".format(logDir)) + "Logging directory specified is not a directory: %s".format(resolvedLogDir)) } checkForLogs() @@ -95,15 +95,16 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis override def getAppUI(appId: String): SparkUI = { try { - val appLogDir = fs.getFileStatus(new Path(logDir, appId)) - loadAppInfo(appLogDir, true)._2 + val appLogDir = fs.getFileStatus(new Path(resolvedLogDir.toString, appId)) + val (_, ui) = loadAppInfo(appLogDir, renderUI = true) + ui } catch { case e: FileNotFoundException => null } } override def getConfig(): Map[String, String] = - Map(("Event Log Location" -> logDir)) + Map("Event Log Location" -> resolvedLogDir.toString) /** * Builds the application list based on the current contents of the log directory. @@ -114,14 +115,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis lastLogCheckTimeMs = getMonotonicTimeMs() logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) try { - val logStatus = fs.listStatus(new Path(logDir)) + val logStatus = fs.listStatus(new Path(resolvedLogDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() - val logInfos = logDirs.filter { - dir => fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE)) + val logInfos = logDirs.filter { dir => + fs.isFile(new Path(dir.getPath, EventLoggingListener.APPLICATION_COMPLETE)) } val currentApps = Map[String, ApplicationHistoryInfo]( - appList.map(app => (app.id -> app)):_*) + appList.map(app => app.id -> app):_*) // For any application that either (i) is not listed or (ii) has changed since the last time // the listing was created (defined by the log dir's modification time), load the app's info. @@ -131,7 +132,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val curr = currentApps.getOrElse(dir.getPath().getName(), null) if (curr == null || curr.lastUpdated < getModificationTime(dir)) { try { - newApps += loadAppInfo(dir, false)._1 + val (app, _) = loadAppInfo(dir, renderUI = false) + newApps += app } catch { case e: Exception => logError(s"Failed to load app info from directory $dir.") } @@ -159,9 +161,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false. 
*/ private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = { - val elogInfo = EventLoggingListener.parseLoggingInfo(logDir.getPath(), fs) val path = logDir.getPath val appId = path.getName + val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs) val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec) val appListener = new ApplicationEventListener replayBus.addListener(appListener) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d7a3e3f120e67..c4ef8b63b0071 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -45,7 +45,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
           <ul class="unstyled">
-            { providerConfig.map(e => <li><strong>{e._1}:</strong> {e._2}</li>) }
+            {providerConfig.map { case (k, v) => <li><strong>{k}:</strong> {v}</li> }}
           </ul>
{ if (allApps.size > 0) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index cacb9da8c947b..d1a64c1912cb8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -25,9 +25,9 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.ui.{WebUI, SparkUI, UIUtils} +import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.SignalLogger /** * A web server that renders SparkUIs of completed applications. @@ -177,7 +177,7 @@ object HistoryServer extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) initSecurity() - val args = new HistoryServerArguments(conf, argStrings) + new HistoryServerArguments(conf, argStrings) val securityManager = new SecurityManager(conf) val providerName = conf.getOption("spark.history.provider") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index be9361b754fc3..25fc76c23e0fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.history import org.apache.spark.SparkConf -import org.apache.spark.util.Utils /** * Command-line parser for the master. 
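The hunk below makes --dir write straight into the conf; for reference, setting the same key directly is equivalent (the directory shown is only an example, and per the FsHistoryProvider change above a scheme-less path is resolved to a file: URI):

    import org.apache.spark.SparkConf

    // Equivalent to passing --dir to the history server; the path is illustrative.
    val conf = new SparkConf().set("spark.history.fs.logDirectory", "/var/log/spark-events")
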
@@ -32,6 +31,7 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] args match { case ("--dir" | "-d") :: value :: tail => logDir = value + conf.set("spark.history.fs.logDirectory", value) parse(tail) case ("--help" | "-h") :: tail => @@ -42,9 +42,6 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] case _ => printUsageAndExit(1) } - if (logDir != null) { - conf.set("spark.history.fs.logDirectory", logDir) - } } private def printUsageAndExit(exitCode: Int) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 21f8667819c44..a70ecdb375373 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -154,6 +154,8 @@ private[spark] class Master( } override def postStop() { + masterMetricsSystem.report() + applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { recoveryCompletionTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 4af5bc3afad6c..687e492a0d6fc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -47,7 +47,6 @@ object CommandUtils extends Logging { */ def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") - val extraOpts = command.extraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq()) // Exists for backwards compatibility with older Spark versions val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString) @@ -62,7 +61,7 @@ object CommandUtils extends Logging { val joined = command.libraryPathEntries.mkString(File.pathSeparator) Seq(s"-Djava.library.path=$joined") } else { - Seq() + Seq() } val permGenOpt = Seq("-XX:MaxPermSize=128m") @@ -71,11 +70,11 @@ object CommandUtils extends Logging { val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" val classPath = Utils.executeAndGetOutput( Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment=command.environment) + extraEnvironment = command.environment) val userClassPath = command.classPathEntries ++ Seq(classPath) Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ libraryOpts ++ extraOpts ++ workerLocalOpts ++ memoryOpts + permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 662d37871e7a6..5caaf6bea3575 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -36,6 +36,7 @@ import org.apache.spark.deploy.master.DriverState.DriverState /** * Manages the execution of one driver, including automatically restarting the driver on failure. + * This is currently only used in standalone cluster deploy mode. 
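Since CommandUtils.buildJavaOpts above now consumes command.javaOpts as a pre-split sequence, here is a sketch of how the submitting side assembles that sequence, mirroring the deploy.Client change earlier in this patch (values are examples; the snippet assumes it lives inside the org.apache.spark package, as Utils is private[spark]):

    import org.apache.spark.SparkConf
    import org.apache.spark.util.Utils

    val conf = new SparkConf().set("spark.executor.memory", "2g")
    val extraJavaOpts = conf.getOption("spark.driver.extraJavaOptions")
      .map(Utils.splitCommandString).getOrElse(Seq.empty)
    // e.g. Seq("-Dspark.executor.memory=2g") ++ whatever spark.driver.extraJavaOptions held
    val javaOpts = Utils.sparkJavaOpts(conf) ++ extraJavaOpts
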
*/ private[spark] class DriverRunner( val driverId: String, @@ -81,7 +82,7 @@ private[spark] class DriverRunner( driverDesc.command.environment, classPath, driverDesc.command.libraryPathEntries, - driverDesc.command.extraJavaOptions) + driverDesc.command.javaOpts) val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem, sparkHome.getAbsolutePath) launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 467317dd9b44c..7be89f9aff0f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.logging.FileAppender /** * Manages the execution of one executor process. + * This is currently only used in standalone mode. */ private[spark] class ExecutorRunner( val appId: String, @@ -72,7 +73,7 @@ private[spark] class ExecutorRunner( } /** - * kill executor process, wait for exit and notify worker to update resource status + * Kill executor process, wait for exit and notify worker to update resource status. * * @param message the exception message which caused the executor's death */ @@ -114,10 +115,13 @@ private[spark] class ExecutorRunner( } def getCommandSeq = { - val command = Command(appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), appDesc.command.environment, - appDesc.command.classPathEntries, appDesc.command.libraryPathEntries, - appDesc.command.extraJavaOptions) + val command = Command( + appDesc.command.mainClass, + appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), + appDesc.command.environment, + appDesc.command.classPathEntries, + appDesc.command.libraryPathEntries, + appDesc.command.javaOpts) CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ce425443051b0..fb5252da96519 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -357,6 +357,7 @@ private[spark] class Worker( } override def postStop() { + metricsSystem.report() registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index b389cb546de6c..ecb358c399819 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.worker.ui -import java.io.File import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -25,7 +24,7 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils import org.apache.spark.Logging -import org.apache.spark.util.logging.{FileAppender, RollingFileAppender} +import org.apache.spark.util.logging.RollingFileAppender private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker @@ -64,11 +63,11 @@ private[spark] class LogPage(parent: WorkerWebUI) extends 
WebUIPage("logPage") w val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val (logDir, params) = (appId, executorId, driverId) match { + val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e") + (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e", s"$a/$e") case (None, None, Some(d)) => - (s"${workDir.getPath}/$d/", s"driverId=$d") + (s"${workDir.getPath}/$d/", s"driverId=$d", d) case _ => throw new Exception("Request must specify either application or driver identifiers") } @@ -120,7 +119,7 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w
- UIUtils.basicSparkPage(content, logType + " log page for " + appId.getOrElse("unknown app")) + UIUtils.basicSparkPage(content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b455c9fcf4bd6..af736de405397 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -88,6 +88,7 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") + executor.stop() context.stop(self) context.system.shutdown() } @@ -98,8 +99,13 @@ private[spark] class CoarseGrainedExecutorBackend( } private[spark] object CoarseGrainedExecutorBackend extends Logging { - def run(driverUrl: String, executorId: String, hostname: String, cores: Int, - workerUrl: Option[String]) { + + private def run( + driverUrl: String, + executorId: String, + hostname: String, + cores: Int, + workerUrl: Option[String]) { SignalLogger.register(log) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3b69bc4ca4142..1bb1b4aae91bb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import java.util.concurrent._ import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark._ import org.apache.spark.scheduler._ @@ -48,6 +48,8 @@ private[spark] class Executor( private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + @volatile private var isStopped = false + // No ip or host:port - just hostname Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -107,6 +109,8 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + startDriverHeartbeater() + def launchTask( context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { val tr = new TaskRunner(context, taskId, taskName, serializedTask) @@ -121,6 +125,12 @@ private[spark] class Executor( } } + def stop() { + env.metricsSystem.report() + isStopped = true + threadPool.shutdown() + } + /** Get the Yarn approved local directories. 
*/ private def getYarnLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -137,11 +147,12 @@ private[spark] class Executor( } class TaskRunner( - execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) + execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false - @volatile private var task: Task[Any] = _ + @volatile var task: Task[Any] = _ + @volatile var attemptedTask: Option[Task[Any]] = None def kill(interruptThread: Boolean) { logInfo(s"Executor is trying to kill $taskName (TID $taskId)") @@ -158,7 +169,6 @@ private[spark] class Executor( val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum val startGCTime = gcTime @@ -200,7 +210,6 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.hostname = Utils.localHostName() m.executorDeserializeTime = taskStart - startTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime @@ -350,4 +359,42 @@ private[spark] class Executor( } } } + + def startDriverHeartbeater() { + val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val timeout = AkkaUtils.lookupTimeout(conf) + val retryAttempts = AkkaUtils.numRetries(conf) + val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + + val t = new Thread() { + override def run() { + // Sleep a random interval so the heartbeats don't end up in sync + Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) + + while (!isStopped) { + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + for (taskRunner <- runningTasks.values()) { + if (!taskRunner.attemptedTask.isEmpty) { + Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + tasksMetrics += ((taskRunner.taskId, metrics)) + } + } + } + + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + Thread.sleep(interval) + } + } + } + t.setDaemon(true) + t.setName("Driver Heartbeater") + t.start() + } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 21fe643b8d71f..56cd8723a3a22 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -23,6 +23,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus} /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. + * + * This class is used to house metrics both for in-progress and completed tasks. In executors, + * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread + * reads it to send in-progress metrics, and the task thread reads it to send metrics along with + * the completed task. 
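For reference, the cadence of that heartbeat thread is governed by the new spark.executor.heartbeatInterval setting read in Executor.startDriverHeartbeater() above, in milliseconds (the value below is only an example):

    import org.apache.spark.SparkConf

    // Defaults to 10000 ms as shown above; here the executor would heartbeat every 5 seconds.
    val conf = new SparkConf().set("spark.executor.heartbeatInterval", "5000")
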
+ * + * So, when adding new fields, take into consideration that the whole object can be serialized for + * shipping off at any time to consumers of the SparkListener interface. */ @DeveloperApi class TaskMetrics extends Serializable { @@ -143,7 +151,7 @@ class ShuffleReadMetrics extends Serializable { /** * Absolute time when this task finished reading shuffle data */ - var shuffleFinishTime: Long = _ + var shuffleFinishTime: Long = -1 /** * Number of blocks fetched in this shuffle by this task (remote or local) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 651511da1b7fe..6ef817d0e587e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -91,6 +91,10 @@ private[spark] class MetricsSystem private (val instance: String, sinks.foreach(_.stop) } + def report(): Unit = { + sinks.foreach(_.report()) + } + def registerSource(source: Source) { sources += source try { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 05852f1f98993..81b9056b40fb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -57,5 +57,9 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 542dce65366b2..9d5f2ae9328ad 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -66,5 +66,9 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index aeb4ad44a0647..d7b5f5c40efae 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -81,4 +81,8 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index ed27234b4e760..2588fe2c9edb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -35,4 +35,6 @@ private[spark] class JmxSink(val property: Properties, val registry: MetricRegis reporter.stop() } + override def report() { } + } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 571539ba5e467..2f65bc8b46609 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -57,4 +57,6 @@ private[spark] class MetricsServlet(val 
property: Properties, val registry: Metr override def start() { } override def stop() { } + + override def report() { } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 6f2b5a06027ea..0d83d8c425ca4 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -20,4 +20,5 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { def start: Unit def stop: Unit + def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 6388ef82cc5db..fabb882cdd4b3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -17,10 +17,11 @@ package org.apache.spark.rdd +import scala.language.existentials + import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.language.existentials import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled + context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index e7221e3032c11..11ebafbf6d457 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -49,8 +49,8 @@ private[spark] case class CoalescedRDDPartition( } /** - * Computes how many of the parents partitions have getPreferredLocation - * as one of their preferredLocations + * Computes the fraction of the parents' partitions containing preferredLocation within + * their getPreferredLocs. 
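A quick worked example of that fraction (numbers made up):

    // 3 of this coalesced partition's 4 parent partitions list the chosen preferred
    // location among their preferred hosts:
    val localFraction = 3.0 / 4   // 0.75
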
* @return locality of this coalesced partition between 0 and 1 */ def localFraction: Double = { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e521612ffc27c..8d92ea01d9a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -20,7 +20,9 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date import java.io.EOFException + import scala.collection.immutable.Map +import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred.FileSplit @@ -39,6 +41,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.NextIterator /** @@ -232,6 +235,14 @@ class HadoopRDD[K, V]( new InterruptibleIterator[(K, V)](context, iter) } + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new HadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + override def getPreferredLocations(split: Partition): Seq[String] = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopPartition] @@ -272,4 +283,25 @@ private[spark] object HadoopRDD { conf.setInt("mapred.task.partition", splitId) conf.set("mapred.job.id", jobID.toString) } + + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. 
+ */ + private[spark] class HadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = { + val partition = split.asInstanceOf[HadoopPartition] + val inputSplit = partition.inputSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f2b3a64bf1345..7dfec9a18ec67 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -32,6 +34,7 @@ import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD private[spark] class NewHadoopPartition( rddId: Int, @@ -157,6 +160,14 @@ class NewHadoopRDD[K, V]( new InterruptibleIterator(context, iter) } + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + override def getPreferredLocations(split: Partition): Seq[String] = { val theSplit = split.asInstanceOf[NewHadoopPartition] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") @@ -165,6 +176,29 @@ class NewHadoopRDD[K, V]( def getConf: Configuration = confBroadcast.value.value } +private[spark] object NewHadoopRDD { + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. 
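The `mapPartitionsWithInputSplit` hook added to `HadoopRDD` above (and to `NewHadoopRDD` below) exposes the `InputSplit` backing each partition. A hypothetical usage sketch, assuming an existing `SparkContext` named `sc`; the cast is needed because `SparkContext.hadoopFile` is statically typed as a plain `RDD`:

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
    import org.apache.spark.rdd.HadoopRDD

    // Hypothetical input path: tag every record with the file it came from.
    val lines = sc.hadoopFile("hdfs:///logs/*", classOf[TextInputFormat],
      classOf[LongWritable], classOf[Text]).asInstanceOf[HadoopRDD[LongWritable, Text]]

    val linesWithFile = lines.mapPartitionsWithInputSplit { (split, iter) =>
      // Each partition of a HadoopRDD is backed by exactly one InputSplit.
      val file = split.asInstanceOf[FileSplit].getPath.toString
      iter.map { case (_, text) => (file, text.toString) }
    }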
+ */ + private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = { + val partition = split.asInstanceOf[NewHadoopPartition] + val inputSplit = partition.serializableHadoopSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } +} + private[spark] class WholeTextFileRDD( sc : SparkContext, inputFormatClass: Class[_ <: WholeTextFileInputFormat], diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index afd7075f686b9..e98bad2026e32 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark.{Logging, RangePartitioner} +import org.apache.spark.annotation.DeveloperApi /** * Extra functions available on RDDs of (key, value) pairs where the key is sortable through @@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner} */ class OrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag, - P <: Product2[K, V] : ClassTag]( + P <: Product2[K, V] : ClassTag] @DeveloperApi() ( self: RDD[P]) - extends Logging with Serializable { - + extends Logging with Serializable +{ private val ordering = implicitly[Ordering[K]] /** @@ -55,15 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in * order of the keys). */ - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { + // TODO: this currently doesn't work on P other than Tuple2! 
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + : RDD[(K, V)] = + { val part = new RangePartitioner(numPartitions, self, ascending) - new ShuffledRDD[K, V, V, P](self, part) - .setKeyOrdering(ordering) - .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING) + new ShuffledRDD[K, V, V](self, part) + .setKeyOrdering(if (ascending) ordering else ordering.reverse) } } - -private[spark] object SortOrder extends Enumeration { - type SortOrder = Value - val ASCENDING, DESCENDING = Value -} diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c04d162a39616..93af50c0a9cd1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,12 +19,10 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.Date -import java.util.{HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap} +import scala.collection.{Map, mutable} import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -34,19 +32,19 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob, +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.spark._ -import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.SparkHadoopWriter import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.random.StratifiedSamplingUtils /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -92,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) } else { - new ShuffledRDD[K, V, C, (K, C)](self, partitioner) + new ShuffledRDD[K, V, C](self, partitioner) .setSerializer(serializer) .setAggregator(aggregator) .setMapSideCombine(mapSideCombine) @@ -195,6 +193,41 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) foldByKey(zeroValue, defaultPartitioner(self))(func) } + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. 
+ * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use + * additional passes over the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling + * without replacement, we need one additional pass over the RDD to guarantee sample size; + * when sampling with replacement, we need two additional passes. + * + * @param withReplacement whether to sample with or without replacement + * @param fractions map of specific keys to sampling rates + * @param seed seed for the random number generator + * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key + * @return RDD containing the sampled subset + */ + def sampleByKey(withReplacement: Boolean, + fractions: Map[K, Double], + exact: Boolean = false, + seed: Long = Utils.random.nextLong): RDD[(K, V)]= { + + require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") + + val samplingFunc = if (withReplacement) { + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed) + } else { + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed) + } + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a @@ -392,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (self.partitioner == Some(partitioner)) { self } else { - new ShuffledRDD[K, V, V, (K, V)](self, partitioner) + new ShuffledRDD[K, V, V](self, partitioner) } } @@ -531,6 +564,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. + * + * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only + * one value per key is preserved in the map returned) */ def collectAsMap(): Map[K, V] = { val data = self.collect() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a6abc49c5359e..e1c49e35abecd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -331,7 +332,7 @@ abstract class RDD[T: ClassTag]( val distributePartition = (index: Int, items: Iterator[T]) => { var position = (new Random(index)).nextInt(numPartitions) items.map { t => - // Note that the hash code of the key will just be the key itself. 
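A usage sketch for the new `sampleByKey` operator defined above, with made-up data and fractions and an existing `SparkContext` named `sc`. With the default `exact = false` the sample costs a single pass; `exact = true` buys exact per-key sample sizes at the price of additional passes:

    import org.apache.spark.SparkContext._   // pair RDD functions via implicit conversion

    // Stratified sample: keep ~50% of "error" records and ~10% of "info" records.
    val logs = sc.parallelize(Seq(("error", "e1"), ("info", "i1"), ("info", "i2"), ("error", "e2")))
    val fractions = Map("error" -> 0.5, "info" -> 0.1)

    // One pass over the data; per-key sample sizes are only approximate.
    val approxSample = logs.sampleByKey(false, fractions)

    // Extra pass(es) over the data; per-key sizes hit math.ceil(fraction * count) exactly.
    val exactSample = logs.sampleByKey(false, fractions, exact = true)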
The HashPartitioner + // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. position = position + 1 (position, t) @@ -340,7 +341,7 @@ abstract class RDD[T: ClassTag]( // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition), + new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition), new HashPartitioner(numPartitions)), numPartitions).values } else { @@ -351,8 +352,8 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, - fraction: Double, + def sample(withReplacement: Boolean, + fraction: Double, seed: Long = Utils.random.nextLong): RDD[T] = { require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { @@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } + def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } + def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) // ======================================================================= // Other internal methods and fields @@ -1239,6 +1236,23 @@ abstract class RDD[T: ClassTag]( /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context = sc + /** + * Private API for changing an RDD's ClassTag. + * Used for internal Java <-> Scala API compatibility. + */ + private[spark] def retag(cls: Class[T]): RDD[T] = { + val classTag: ClassTag[T] = ClassTag.apply(cls) + this.retag(classTag) + } + + /** + * Private API for changing an RDD's ClassTag. + * Used for internal Java <-> Scala API compatibility. 
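On the `collectAsMap` caveat documented above: because the result is a plain map rather than a multimap, duplicate keys silently collapse to a single value. A small illustration with hypothetical data, again assuming an existing `SparkContext` named `sc`:

    import org.apache.spark.SparkContext._

    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))

    // Only one of the two "a" values survives; which one is not defined.
    val asMap = pairs.collectAsMap()

    // Group first if all values per key are needed.
    val grouped = pairs.groupByKey().collectAsMap()   // "a" maps to both 1 and 2 here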
+ */ + private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = { + this.mapPartitions(identity, preservesPartitioning = true)(classTag) + } + // Avoid handling doCheckpoint multiple times to prevent excessive recursion @transient private var doCheckpointCalled = false diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index c3b2a33fb54d0..f67e5f1857979 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) cpRDD = Some(newRDD) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() } logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } @@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} +// Used for synchronization +private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index da4a8c3dc22b1..d9fe6847254fa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -38,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * @tparam V the value class. * @tparam C the combiner class. */ +// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( +class ShuffledRDD[K, V, C]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) - extends RDD[P](prev.context, Nil) { + extends RDD[(K, C)](prev.context, Nil) { private var serializer: Option[Serializer] = None @@ -52,41 +52,32 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( private var mapSideCombine: Boolean = false - private var sortOrder: Option[SortOrder] = None - /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ - def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { + def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = { this.serializer = Option(serializer) this } /** Set key ordering for RDD's shuffle. */ - def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = { + def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = { this.keyOrdering = Option(keyOrdering) this } /** Set aggregator for RDD's shuffle. */ - def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = { + def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = { this.aggregator = Option(aggregator) this } /** Set mapSideCombine flag for RDD's shuffle. 
*/ - def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = { + def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = { this.mapSideCombine = mapSideCombine this } - /** Set sort order for RDD's sorting. */ - def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = { - this.sortOrder = Option(sortOrder) - this - } - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializer, - keyOrdering, aggregator, mapSideCombine, sortOrder)) + List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } override val partitioner = Some(part) @@ -95,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } - override def compute(split: Partition, context: TaskContext): Iterator[P] = { + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) .read() - .asInstanceOf[Iterator[P]] + .asInstanceOf[Iterator[(K, C)]] } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 21c6e07d69f90..197167ecad0bd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -25,21 +25,32 @@ import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} import org.apache.spark.annotation.DeveloperApi -private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int) +/** + * Partition for UnionRDD. 
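With the fourth type parameter `P` removed, a `ShuffledRDD` now always produces `(K, C)` pairs, and sort direction is no longer a separate `SortOrder` flag: a descending `sortByKey` simply passes `ordering.reverse` to `setKeyOrdering`. A sketch of direct construction with the simplified signature (ordinarily `PairRDDFunctions`/`OrderedRDDFunctions` do this for you), assuming an existing `SparkContext` named `sc`:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.ShuffledRDD

    val pairs = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b")))

    // Three type parameters now: key, value, combiner. No Product2-typed `P` any more.
    val shuffled = new ShuffledRDD[Int, String, String](pairs, new HashPartitioner(2))
      .setKeyOrdering(Ordering[Int].reverse)   // descending order, formerly SortOrder.DESCENDING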
+ * + * @param idx index of the partition + * @param rdd the parent RDD this partition refers to + * @param parentRddIndex index of the parent RDD this partition refers to + * @param parentRddPartitionIndex index of the partition within the parent RDD + * this partition refers to + */ +private[spark] class UnionPartition[T: ClassTag]( + idx: Int, + @transient rdd: RDD[T], + val parentRddIndex: Int, + @transient parentRddPartitionIndex: Int) extends Partition { - var split: Partition = rdd.partitions(splitIndex) - - def iterator(context: TaskContext) = rdd.iterator(split, context) + var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(split) + def preferredLocations() = rdd.preferredLocations(parentPartition) override val index: Int = idx @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - split = rdd.partitions(splitIndex) + parentPartition = rdd.partitions(parentRddPartitionIndex) oos.defaultWriteObject() } } @@ -47,14 +58,14 @@ private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitInd @DeveloperApi class UnionRDD[T: ClassTag]( sc: SparkContext, - @transient var rdds: Seq[RDD[T]]) + var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { val array = new Array[Partition](rdds.map(_.partitions.size).sum) var pos = 0 - for (rdd <- rdds; split <- rdd.partitions) { - array(pos) = new UnionPartition(pos, rdd, split.index) + for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { + array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) pos += 1 } array @@ -70,9 +81,17 @@ class UnionRDD[T: ClassTag]( deps } - override def compute(s: Partition, context: TaskContext): Iterator[T] = - s.asInstanceOf[UnionPartition[T]].iterator(context) + override def compute(s: Partition, context: TaskContext): Iterator[T] = { + val part = s.asInstanceOf[UnionPartition[T]] + val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]] + parentRdd.iterator(part.parentPartition, context) + } override def getPreferredLocations(s: Partition): Seq[String] = s.asInstanceOf[UnionPartition[T]].preferredLocations() + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dc6142ab79d03..d87c3048985fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,11 +17,11 @@ package org.apache.spark.scheduler -import java.io.{NotSerializableException, PrintWriter, StringWriter} +import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -29,17 +29,18 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import akka.actor._ -import akka.actor.OneForOneStrategy import akka.actor.SupervisorStrategy.Stop import akka.pattern.ask import akka.util.Timeout import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast 
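With the reworked `UnionPartition`, each output partition records which parent RDD (`parentRddIndex`) and which partition within that parent it reads, and `compute` resolves the parent through `dependencies` rather than holding the RDD itself; together with the new `clearDependencies`, this lets a checkpointed `UnionRDD` drop references to its parents. A small sketch of the resulting layout, assuming an existing `SparkContext` named `sc`:

    val a = sc.parallelize(1 to 10, 2)    // parent RDD index 0, partitions 0-1
    val b = sc.parallelize(11 to 20, 3)   // parent RDD index 1, partitions 0-2

    val u = a.union(b)
    // 2 + 3 = 5 partitions; e.g. union partition 3 maps to (parentRddIndex = 1,
    // parentRddPartitionIndex = 1) and is computed by asking parent `b` for its partition 1.
    assert(u.partitions.length == 5)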
import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} +import org.apache.spark.storage._ import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -114,6 +115,10 @@ class DAGScheduler( private val dagSchedulerActorSupervisor = env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) + // A closure serializer that we reuse. + // This is only safe because DAGScheduler runs in a single thread. + private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + private[scheduler] var eventProcessActor: ActorRef = _ private def initializeEventProcessActor() { @@ -149,6 +154,23 @@ class DAGScheduler( eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics) + blockManagerId: BlockManagerId): Boolean = { + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + implicit val timeout = Timeout(600 seconds) + + Await.result( + blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), + timeout.duration).asInstanceOf[Boolean] + } + // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { eventProcessActor ! 
ExecutorLost(execId) @@ -189,11 +211,15 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => + // We are going to register ancestor shuffle dependencies + registerShuffleDependencies(shuffleDep, jobId) + // Then register current shuffleDep val stage = newOrUsedStage( shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, shuffleDep.rdd.creationSite) shuffleToMapStage(shuffleDep.shuffleId) = stage + stage } } @@ -258,6 +284,9 @@ class DAGScheduler( private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(r: RDD[_]) { if (!visited(r)) { visited += r @@ -268,18 +297,69 @@ class DAGScheduler( case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => - visit(dep.rdd) + waitingForVisit.push(dep.rdd) } } } } - visit(rdd) + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } parents.toList } + // Find ancestor missing shuffle dependencies and register into shuffleToMapStage + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) + while (!parentsWithNoMapStage.isEmpty) { + val currentShufDep = parentsWithNoMapStage.pop() + val stage = + newOrUsedStage( + currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, + currentShufDep.rdd.creationSite) + shuffleToMapStage(currentShufDep.shuffleId) = stage + } + } + + // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { + val parents = new Stack[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] + def visit(r: RDD[_]) { + if (!visited(r)) { + visited += r + for (dep <- r.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + if (!shuffleToMapStage.contains(shufDep.shuffleId)) { + parents.push(shufDep) + } + + waitingForVisit.push(shufDep.rdd) + case _ => + waitingForVisit.push(dep.rdd) + } + } + } + } + + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } + parents + } + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd @@ -292,13 +372,16 @@ class DAGScheduler( missing += mapStage } case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } missing.toList } @@ -361,9 +444,6 @@ class DAGScheduler( // data structures based on StageId stageIdToStage -= stageId - ShuffleMapTask.removeStage(stageId) - ResultTask.removeStage(stageId) - logDebug("After removal of stage %d, remaining stages = %d" 
.format(stageId, stageIdToStage.size)) } @@ -691,49 +771,83 @@ class DAGScheduler( } } - /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() + + val properties = if (jobIdToActiveJob.contains(jobId)) { + jobIdToActiveJob(stage.jobId).properties + } else { + // this stage will be assigned to "default" pool + null + } + + runningStages += stage + // SparkListenerStageSubmitted should be posted before testing whether tasks are + // serializable. If tasks are not serializable, a SparkListenerStageCompleted event + // will be posted, which should always come after a corresponding SparkListenerStageSubmitted + // event. + listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + + // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. + // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast + // the serialized copy of the RDD and for each task we will deserialize it, which means each + // task gets a different copy of the RDD. This provides stronger isolation between tasks that + // might modify state of objects referenced in their closures. This is necessary in Hadoop + // where the JobConf/Configuration object is not thread-safe. + var taskBinary: Broadcast[Array[Byte]] = null + try { + // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). + // For ResultTask, serialize and broadcast (rdd, func). + val taskBinaryBytes: Array[Byte] = + if (stage.isShuffleMap) { + closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() + } else { + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) + } catch { + // In the case of a failure during serialization, abort the stage. + case e: NotSerializableException => + abortStage(stage, "Task not serializable: " + e.toString) + runningStages -= stage + return + case NonFatal(e) => + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + if (stage.isShuffleMap) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + val part = stage.rdd.partitions(p) + tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs) } } else { // This is a final stage; figure out its job's missing partitions val job = stage.resultOfJob.get for (id <- 0 until job.numPartitions if !job.finished(id)) { - val partition = job.partitions(id) - val locs = getPreferredLocs(stage.rdd, partition) - tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ResultTask(stage.id, taskBinary, part, locs, id) } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } - if (tasks.size > 0) { - runningStages += stage - // SparkListenerStageSubmitted should be posted before testing whether tasks are - // serializable. 
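The per-stage `serializedInfoCache` in `ShuffleMapTask`/`ResultTask` is replaced by a single broadcast of the serialized `(rdd, dep)` or `(rdd, func)` pair, so each task ships only a small `taskBinary` handle and deserializes its own private copy of the RDD. A generic sketch of the mechanism with hypothetical helper names, built only from the calls used in the code above and in the task implementations further down:

    import java.nio.ByteBuffer

    import scala.reflect.ClassTag

    import org.apache.spark.{SparkContext, SparkEnv}
    import org.apache.spark.broadcast.Broadcast

    // Driver side: serialize the payload once and broadcast the raw bytes.
    def broadcastSerialized(sc: SparkContext, payload: AnyRef): Broadcast[Array[Byte]] = {
      val ser = SparkEnv.get.closureSerializer.newInstance()
      sc.broadcast(ser.serialize(payload).array())
    }

    // Executor side: each task deserializes a fresh, private copy, so mutable state in
    // closures (e.g. a Hadoop JobConf) is isolated between tasks.
    def deserializeCopy[T: ClassTag](binary: Broadcast[Array[Byte]]): T = {
      val ser = SparkEnv.get.closureSerializer.newInstance()
      ser.deserialize[T](ByteBuffer.wrap(binary.value), Thread.currentThread.getContextClassLoader)
    }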
If tasks are not serializable, a SparkListenerStageCompleted event - // will be posted, which should always come after a corresponding SparkListenerStageSubmitted - // event. - listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) - // Preemptively serialize a task to make sure it can be serialized. We are catching this // exception here because it would be fairly hard to catch the non-serializable exception // down the road, where we have several different implementations for local scheduler and // cluster schedulers. + // + // We've already serialized RDDs and closures in taskBinary, but here we check for all other + // objects such as Partition. try { - SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) + closureSerializer.serialize(tasks.head) } catch { case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) @@ -752,6 +866,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.info.submissionTime = Some(clock.getTime()) } else { + // Because we posted SparkListenerStageSubmitted earlier, we should post + // SparkListenerStageCompleted here in case there are no tasks to run. + listenerBus.post(SparkListenerStageCompleted(stage.info)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) runningStages -= stage @@ -1063,6 +1180,9 @@ class DAGScheduler( } val visitedRdds = new HashSet[RDD[_]] val visitedStages = new HashSet[Stage] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visitedRdds(rdd)) { visitedRdds += rdd @@ -1072,15 +1192,18 @@ class DAGScheduler( val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage - visit(mapStage.rdd) + waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } visitedRdds.contains(target.rdd) } @@ -1092,6 +1215,22 @@ class DAGScheduler( */ private[spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + getPreferredLocsInternal(rdd, partition, new HashSet) + } + + /** Recursive implementation for getPreferredLocs. */ + private def getPreferredLocsInternal( + rdd: RDD[_], + partition: Int, + visited: HashSet[(RDD[_],Int)]) + : Seq[TaskLocation] = + { + // If the partition has already been visited, no need to re-visit. + // This avoids exponential path exploration. SPARK-695 + if (!visited.add((rdd,partition))) { + // Nil has already been returned for previously visited partitions. 
+ return Nil + } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (!cached.isEmpty) { @@ -1108,7 +1247,7 @@ class DAGScheduler( rdd.dependencies.foreach { case n: NarrowDependency[_] => for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) + val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index ae6ca9f4e7bf5..406147f167bf3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -29,7 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{FileLogger, JsonProtocol} +import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} /** * A SparkListener that logs events to persistent storage. @@ -55,7 +55,7 @@ private[spark] class EventLoggingListener( private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") private val name = appName.replaceAll("[ :/]", "-").toLowerCase + "-" + System.currentTimeMillis - val logDir = logBaseDir + "/" + name + val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS)) @@ -215,7 +215,7 @@ private[spark] object EventLoggingListener extends Logging { } catch { case e: Exception => logError("Exception in parsing logging info from directory %s".format(logDir), e) - EventLoggingInfo.empty + EventLoggingInfo.empty } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index bbf9f7388b074..d09fd7aa57642 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,134 +17,56 @@ package org.apache.spark.scheduler -import scala.language.existentials +import java.nio.ByteBuffer import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} - -private[spark] object ResultTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. 
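Two patterns recur in the DAGScheduler changes above: recursion replaced by an explicit `Stack` so that very deep lineages cannot overflow the JVM stack, and a `visited` set so that shared ancestors are not explored once per path (the SPARK-695 fix in `getPreferredLocsInternal`). A generic sketch of the combined pattern, with hypothetical names:

    import scala.collection.mutable

    import org.apache.spark.rdd.RDD

    // Depth-first walk over an RDD lineage without recursion: the explicit stack bounds the
    // call depth, and the visited set ensures each RDD is processed exactly once.
    def visitLineage(root: RDD[_])(handle: RDD[_] => Unit): Unit = {
      val visited = new mutable.HashSet[RDD[_]]
      val waitingForVisit = new mutable.Stack[RDD[_]]
      waitingForVisit.push(root)
      while (waitingForVisit.nonEmpty) {
        val rdd = waitingForVisit.pop()
        if (visited.add(rdd)) {          // add returns false if the RDD was already seen
          handle(rdd)
          rdd.dependencies.foreach(dep => waitingForVisit.push(dep.rdd))
        }
      }
    }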
- private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = - { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(func) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = - { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - (rdd, func) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD /** * A task that sends back the output to the driver application. * - * See [[org.apache.spark.scheduler.Task]] for more information. + * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param rdd input to func - * @param func a function to apply on a partition of the RDD - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each + * partition of the given RDD. Once deserialized, the type should be + * (RDD[T], (TaskContext, Iterator[T]) => U). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). */ private[spark] class ResultTask[T, U]( stageId: Int, - var rdd: RDD[T], - var func: (TaskContext, Iterator[T]) => U, - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient locs: Seq[TaskLocation], - var outputId: Int) - extends Task[U](stageId, _partitionId) with Externalizable { + val outputId: Int) + extends Task[U](stageId, partition.index) with Serializable { - def this() = this(0, null, null, 0, null, 0) - - var split = if (rdd == null) null else rdd.partitions(partitionId) - - @transient private val preferredLocs: Seq[TaskLocation] = { + @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } override def runTask(context: TaskContext): U = { + // Deserialize the RDD and the func using the broadcast variables. + val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) try { - func(context, rdd.iterator(split, context)) + func(context, rdd.iterator(partition, context)) } finally { context.executeOnCompleteCallbacks() } } + // This is only callable on the driver side. 
override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ResultTask.serializeInfo( - stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeInt(outputId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) - rdd = rdd_.asInstanceOf[RDD[T]] - func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partitionId = in.readInt() - outputId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fdaf1de83f051..11255c07469d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,134 +17,55 @@ package org.apache.spark.scheduler -import scala.language.existentials - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import java.nio.ByteBuffer -import scala.collection.mutable.HashMap +import scala.language.existentials import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter -private[spark] object ShuffleMapTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] - (rdd, dep) - } - - // Since both the JarSet and FileSet have the same format this is used for both. 
- def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) - val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - HashMap(set.toSeq: _*) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - /** - * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner - * specified in the ShuffleDependency). - * - * See [[org.apache.spark.scheduler.Task]] for more information. - * +* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner +* specified in the ShuffleDependency). +* +* See [[org.apache.spark.scheduler.Task]] for more information. +* * @param stageId id of the stage this task belongs to - * @param rdd the final RDD in this stage - * @param dep the ShuffleDependency - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * the type should be (RDD[_], ShuffleDependency[_, _, _]). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling */ private[spark] class ShuffleMapTask( stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_, _, _], - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, _partitionId) - with Externalizable - with Logging { + extends Task[MapStatus](stageId, partition.index) with Logging { - protected def this() = this(0, null, null, 0, null) + /** A constructor used only in test suites. This does not require passing in an RDD. */ + def this(partitionId: Int) { + this(0, null, new Partition { override def index = 0 }, null) + } @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partitionId) - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partitionId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } - override def runTask(context: TaskContext): MapStatus = { + // Deserialize the RDD using the broadcast variable. 
+ val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) - writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) + writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) return writer.stop(success = true).get } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 82163eadd56e9..d01d318633877 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -75,6 +75,12 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorMetricsUpdate( + execId: String, + taskMetrics: Seq[(Long, Int, TaskMetrics)]) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) extends SparkListenerEvent @@ -158,6 +164,11 @@ trait SparkListener { * Called when the application ends */ def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { } + + /** + * Called when the driver receives task metrics from an executor in a heartbeat. + */ + def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index ed9fb24bc8ce8..e79ffd7a3587d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -68,6 +68,8 @@ private[spark] trait SparkListenerBus extends Logging { foreachListener(_.onApplicationStart(applicationStart)) case applicationEnd: SparkListenerApplicationEnd => foreachListener(_.onApplicationEnd(applicationEnd)) + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) case SparkListenerShutdown => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5871edeb856ad..5c5e421404a21 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -26,6 +26,8 @@ import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils + /** * A unit of execution. 
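The new `SparkListenerExecutorMetricsUpdate` event above carries `(taskId, stageId, TaskMetrics)` tuples reported through executor heartbeats, and `SparkListenerBus` now routes it to `onExecutorMetricsUpdate`. A sketch of a listener that consumes it; registering it via `SparkContext.addSparkListener` is assumed:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate}

    // Sketch: log per-task progress as metrics stream in from executor heartbeats.
    class InProgressMetricsListener extends SparkListener {
      override def onExecutorMetricsUpdate(update: SparkListenerExecutorMetricsUpdate) {
        for ((taskId, stageId, metrics) <- update.taskMetrics) {
          println(s"task $taskId (stage $stageId) on executor ${update.execId}: " +
            s"${metrics.executorRunTime} ms of run time so far")
        }
      }
    }

    // sc.addSparkListener(new InProgressMetricsListener)   // hypothetical wiring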
We have two kinds of Task's in Spark: @@ -44,6 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + context.taskMetrics.hostname = Utils.localHostName(); taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 819c35257b5a7..1a0b877c8a5e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -18,6 +18,8 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId /** * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. @@ -54,4 +56,12 @@ private[spark] trait TaskScheduler { // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int + + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index be3673c48eda8..d2f764fc22f54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -32,6 +32,9 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.util.Utils +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import akka.actor.Props /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. @@ -320,6 +323,26 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. 
+ */ + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId): Boolean = { + val metricsWithStageIds = taskMetrics.flatMap { + case (id, metrics) => { + taskIdToTaskSetId.get(id) + .flatMap(activeTaskSets.get) + .map(_.stageId) + .map(x => (id, x, metrics)) + } + } + dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) + } + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { taskSetManager.handleTaskGettingResult(tid) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bf2dc88e29048..48aaaa54bdb35 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} @@ -46,6 +46,7 @@ private[spark] class SparkDeploySchedulerBackend( CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") + .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp => cp.split(java.io.File.pathSeparator) } @@ -54,9 +55,11 @@ private[spark] class SparkDeploySchedulerBackend( cp.split(java.io.File.pathSeparator) } - val command = Command( - "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, - classPathEntries, libraryPathEntries, extraJavaOpts) + // Start executors with a few necessary configs for registering with the scheduler + val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", + args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) val sparkHome = sc.getSparkHome() val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, sparkHome, sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 5b897597fa285..3d1cf312ccc97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,8 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.storage.BlockManagerId private case class ReviveOffers() @@ -32,6 +33,8 @@ private case 
class StatusUpdate(taskId: Long, state: TaskState, serializedData: private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class StopExecutor() + /** * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend @@ -63,6 +66,9 @@ private[spark] class LocalActor( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + + case StopExecutor => + executor.stop() } def reviveOffers() { @@ -91,6 +97,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def stop() { + localActor ! StopExecutor } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index fa79b25759153..e60b802a86a14 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -48,11 +48,12 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 + private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val registrator = conf.getOption("spark.kryo.registrator") - def newKryoOutput() = new KryoOutput(bufferSize) + def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 5b0940ecce29d..df98d18fa8193 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,7 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). */ -class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. 
*/ override def registerShuffle[K, V, C]( shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 76cdb8f4f8e8a..7c9dc8e5f88ef 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -18,11 +18,11 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{InterruptibleIterator, TaskContext} -import org.apache.spark.rdd.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.util.collection.ExternalSorter -class HashShuffleReader[K, C]( +private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, @@ -36,8 +36,8 @@ class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, - Serializer.getSerializer(dep.serializer)) + val ser = Serializer.getSerializer(dep.serializer) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { @@ -48,19 +48,23 @@ class HashShuffleReader[K, C]( } else if (dep.aggregator.isEmpty && dep.mapSideCombine) { throw new IllegalStateException("Aggregator is empty for map-side combine") } else { - iter + // Convert the Product2s to pairs since this is what downstream RDDs currently expect + iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) } - val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield { - val buf = aggregatedIter.toArray - if (sortOrder == SortOrder.ASCENDING) { - buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator - } else { - buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator - } + // Sort the output if there is a sort ordering defined. + dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, + // the ExternalSorter won't spill to disk. 
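      // [Editor's note — illustrative, not part of this patch] This branch is only taken when
      // the ShuffleDependency carries a key ordering (e.g. sortByKey); reduceByKey/groupByKey
      // shuffles leave keyOrdering as None and fall through to aggregatedIter below. Whether
      // the sort may spill is governed by spark.shuffle.spill (assumed true by default), e.g.:
      //   conf.set("spark.shuffle.spill", "false")  // keep the reduce-side sort fully in memory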
+ val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + sorter.write(aggregatedIter) + context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + sorter.iterator + case None => + aggregatedIter } - - sortedIter.getOrElse(aggregatedIter) } /** Close this reader */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 9b78228519da4..45d3b8b9b8725 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -class HashShuffleWriter[K, V]( +private[spark] class HashShuffleWriter[K, V]( handle: BaseShuffleHandle[K, V, _], mapId: Int, context: TaskContext) @@ -33,6 +33,10 @@ class HashShuffleWriter[K, V]( private val dep = handle.dependency private val numOutputSplits = dep.partitioner.numPartitions private val metrics = context.taskMetrics + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. private var stopping = false private val blockManager = SparkEnv.get.blockManager @@ -61,7 +65,8 @@ class HashShuffleWriter[K, V]( } /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { + override def stop(initiallySuccess: Boolean): Option[MapStatus] = { + var success = initiallySuccess try { if (stopping) { return None @@ -69,15 +74,16 @@ class HashShuffleWriter[K, V]( stopping = true if (success) { try { - return Some(commitWritesAndBuildStatus()) + Some(commitWritesAndBuildStatus()) } catch { case e: Exception => + success = false revertWrites() throw e } } else { revertWrites() - return None + None } } finally { // Release the writers back to the shuffle block manager. @@ -96,8 +102,7 @@ class HashShuffleWriter[K, V]( var totalBytes = 0L var totalTime = 0L val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() + writer.commitAndClose() val size = writer.fileSegment().length totalBytes += size totalTime += writer.timeWriting() @@ -116,8 +121,7 @@ class HashShuffleWriter[K, V]( private def revertWrites(): Unit = { if (shuffle != null && shuffle.writers != null) { for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + writer.revertPartialWritesAndClose() } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala new file mode 100644 index 0000000000000..6dcca47ea7c0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{DataInputStream, FileInputStream} + +import org.apache.spark.shuffle._ +import org.apache.spark.{TaskContext, ShuffleDependency} +import org.apache.spark.shuffle.hash.HashShuffleReader +import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId} + +private[spark] class SortShuffleManager extends ShuffleManager { + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + // We currently use the same block store shuffle fetcher as the hash-based shuffle. + new HashShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) + : ShuffleWriter[K, V] = { + new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Unit = {} + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = {} + + /** Get the location of a block in a map output file. Uses the index file we create for it. */ + def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = { + // The block is actually going to be a range of a single map output file for this map, so + // figure out the ID of the consolidated file, then the offset within that from our index + val consolidatedId = blockId.copy(reduceId = 0) + val indexFile = diskManager.getFile(consolidatedId.name + ".index") + val in = new DataInputStream(new FileInputStream(indexFile)) + try { + in.skip(blockId.reduceId * 8) + val offset = in.readLong() + val nextOffset = in.readLong() + new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset) + } finally { + in.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala new file mode 100644 index 0000000000000..9a356d0dbaf17 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream} + +import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class SortShuffleWriter[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency + private val numPartitions = dep.partitioner.numPartitions + + private val blockManager = SparkEnv.get.blockManager + private val ser = Serializer.getSerializer(dep.serializer.orNull) + + private val conf = SparkEnv.get.conf + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + + private var sorter: ExternalSorter[K, V, _] = null + private var outputFile: File = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private var stopping = false + + private var mapStatus: MapStatus = null + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + // Get an iterator with the elements for each partition ID + val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = { + if (dep.mapSideCombine) { + if (!dep.aggregator.isDefined) { + throw new IllegalStateException("Aggregator is empty for map-side combine") + } + sorter = new ExternalSorter[K, V, C]( + dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + sorter.write(records) + sorter.partitionedIterator + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we + // don't care whether the keys get sorted in each partition; that will be done on the + // reduce side if the operation being run is sortByKey. + sorter = new ExternalSorter[K, V, V]( + None, Some(dep.partitioner), None, dep.serializer) + sorter.write(records) + sorter.partitionedIterator + } + } + + // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later + // serve different ranges of this file using an index file that we create at the end. 
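    // [Editor's sketch, not part of this patch] The layout this produces, assuming a map task
    // with 3 reduce partitions holding 10, 0 and 25 bytes of serialized data respectively:
    //   data file  "shuffle_<shuffleId>_<mapId>_0":        [ p0: 10 bytes | p2: 25 bytes ]
    //   index file "shuffle_<shuffleId>_<mapId>_0.index":  longs 0, 10, 10, 35
    // SortShuffleManager.getBlockLocation (above) then serves reduce partition r by skipping
    // r * 8 bytes into the index, reading offsets r and r+1, and returning that byte range of
    // the data file, e.g. partition 2 -> bytes [10, 35).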
+ val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) + outputFile = blockManager.diskBlockManager.getFile(blockId) + + // Track location of each range in the output file + val offsets = new Array[Long](numPartitions + 1) + val lengths = new Array[Long](numPartitions) + + // Statistics + var totalBytes = 0L + var totalTime = 0L + + for ((id, elements) <- partitions) { + if (elements.hasNext) { + val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize) + for (elem <- elements) { + writer.write(elem) + } + writer.commitAndClose() + val segment = writer.fileSegment() + offsets(id + 1) = segment.offset + segment.length + lengths(id) = segment.length + totalTime += writer.timeWriting() + totalBytes += segment.length + } else { + // The partition is empty; don't create a new writer to avoid writing headers, etc + offsets(id + 1) = offsets(id) + } + } + + val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics) + context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + + // Write an index file with the offsets of each block, plus a final offset at the end for the + // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure + // out where each block begins and ends. + + val diskBlockManager = blockManager.diskBlockManager + val indexFile = diskBlockManager.getFile(blockId.name + ".index") + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + var i = 0 + while (i < numPartitions + 1) { + out.writeLong(offsets(i)) + i += 1 + } + } finally { + out.close() + } + + // Register our map output with the ShuffleBlockManager, which handles cleaning it over time + blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions) + + mapStatus = new MapStatus(blockManager.blockManagerId, + lengths.map(MapOutputTracker.compressSize)) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + return Option(mapStatus) + } else { + // The map task failed, so delete our output file if we created one + if (outputFile != null) { + outputFile.delete() + } + return None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + sorter.stop() + sorter = null + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 69905a960a2ca..ccf830e118ee7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -200,14 +200,17 @@ object BlockFetcherIterator { // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlocksToFetch) { - getLocalFromDisk(id, serializer) match { - case Some(iter) => { - // Pass 0 as size since it's not in flight - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") + try { + // 
getLocalFromDisk never return None but throws BlockException + val iter = getLocalFromDisk(id, serializer).get + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) + logDebug("Got local block " + id) + } catch { + case e: Exception => { + logError(s"Error occurred while fetching local blocks", e) + results.put(new FetchResult(id, -1, null)) + return } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 42ec181b00bb3..c1756ac905417 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -54,11 +54,15 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { } @DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) - extends BlockId { +case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +@DeveloperApi +case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { + def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" +} + @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) @@ -88,6 +92,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -99,6 +104,8 @@ object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => + ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d746526639e58..c0a06017945f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -116,15 +116,6 @@ private[spark] class BlockManager( private var asyncReregisterTask: Future[Unit] = null private val asyncReregisterLock = new Object - private def heartBeat(): Unit = { - if (!master.sendHeartBeat(blockManagerId)) { - reregister() - } - } - - private val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - private var heartBeatTask: Cancellable = null - private val metadataCleaner = new MetadataCleaner( MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) private val broadcastCleaner = new MetadataCleaner( @@ -161,11 +152,6 @@ private[spark] class BlockManager( private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { - Utils.tryOrExit { heartBeat() } - } - } } /** @@ 
-195,7 +181,7 @@ private[spark] class BlockManager( * * Note that this method must be called without any BlockInfo locks held. */ - private def reregister(): Unit = { + def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") master.registerBlockManager(blockManagerId, maxMemory, slaveActor) @@ -1065,9 +1051,6 @@ private[spark] class BlockManager( } def stop(): Unit = { - if (heartBeatTask != null) { - heartBeatTask.cancel() - } connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() @@ -1095,12 +1078,6 @@ private[spark] object BlockManager extends Logging { (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } - def getHeartBeatFrequency(conf: SparkConf): Long = - conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4 - - def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = - conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false) - /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7897fade2df2b..669307765d1fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -21,7 +21,6 @@ import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global import akka.actor._ -import akka.pattern.ask import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ @@ -29,8 +28,8 @@ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3) - val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000) + private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) + private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" @@ -42,15 +41,6 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log logInfo("Removed " + execId + " successfully in removeExecutor") } - /** - * Send the driver actor a heart beat from the slave. Returns true if everything works out, - * false if the driver does not know about the given block manager, which means the block - * manager should re-register. - */ - def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askDriverWithReply[Boolean](HeartBeat(blockManagerId)) - } - /** Register the BlockManager's id with the driver. */ def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { logInfo("Trying to register BlockManager") @@ -223,33 +213,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log * throw a SparkException if this fails. 
*/ private def askDriverWithReply[T](message: Any): T = { - // TODO: Consider removing multiple attempts - if (driverActor == null) { - throw new SparkException("Error sending message to BlockManager as driverActor is null " + - "[message = " + message + "]") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPTS) { - attempts += 1 - try { - val future = driverActor.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("BlockManagerMaster returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) - } - Thread.sleep(AKKA_RETRY_INTERVAL_MS) - } - - throw new SparkException( - "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) + AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, + timeout) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index de1cc5539fb48..94f5a4bb2e9cd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,25 +52,24 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = conf.get("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong + val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", + math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.get("spark.storage.blockManagerTimeoutIntervalMs", - "60000").toLong + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", + 60000) var timeoutCheckingTask: Cancellable = null override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } + import context.dispatcher + timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, + checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) super.preStart() } def receive = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -129,8 +128,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case ExpireDeadHosts => expireDeadHosts() - case HeartBeat(blockManagerId) => - sender ! heartBeat(blockManagerId) + case BlockManagerHeartbeat(blockManagerId) => + sender ! 
heartbeatReceived(blockManagerId) case other => logWarning("Got unknown message: " + other) @@ -216,7 +215,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val minSeenTime = now - slaveTimeout val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { + if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") toRemove += info.blockManagerId @@ -230,7 +229,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } - private def heartBeat(blockManagerId: BlockManagerId): Boolean = { + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { blockManagerId.executorId == "" && !isLocal } else { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 2b53bf33b5fba..10b65286fb7db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef -private[storage] object BlockManagerMessages { +private[spark] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// @@ -53,8 +53,6 @@ private[storage] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, @@ -124,5 +122,7 @@ private[storage] object BlockManagerMessages { case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) extends ToBlockManagerMaster + case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a2687e6be4e34..01d46e1ffc960 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -39,16 +39,16 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { def isOpen: Boolean /** - * Flush the partial writes and commit them as a single atomic block. Return the - * number of bytes written for this commit. + * Flush the partial writes and commit them as a single atomic block. */ - def commit(): Long + def commitAndClose(): Unit /** * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. 
*/ - def revertPartialWrites() + def revertPartialWritesAndClose() /** * Writes an object. @@ -57,6 +57,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { /** * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. */ def fileSegment(): FileSegment @@ -108,7 +109,7 @@ private[spark] class DiskBlockObjectWriter( private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private val initialPosition = file.length() - private var lastValidPosition = initialPosition + private var finalPosition: Long = -1 private var initialized = false private var _timeWriting = 0L @@ -116,7 +117,6 @@ private[spark] class DiskBlockObjectWriter( fos = new FileOutputStream(file, true) ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() - lastValidPosition = initialPosition bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) initialized = true @@ -147,28 +147,36 @@ private[spark] class DiskBlockObjectWriter( override def isOpen: Boolean = objOut != null - override def commit(): Long = { + override def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. objOut.flush() bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition + close() } + finalPosition = file.length() } - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. 
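  // [Editor's sketch, not part of this patch] The calling pattern the new contract expects,
  // mirroring HashShuffleWriter and SortShuffleWriter above:
  //
  //   val writer = blockManager.getDiskWriter(blockId, file, serializer, bufferSize)
  //   try {
  //     records.foreach(writer.write)
  //     writer.commitAndClose()                // fileSegment()/bytesWritten valid only after this
  //     val segment = writer.fileSegment()
  //   } catch {
  //     case e: Exception =>
  //       writer.revertPartialWritesAndClose() // truncates the file back to its initial position
  //       throw e
  //   }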
+ override def revertPartialWritesAndClose() { + try { + if (initialized) { + objOut.flush() + bs.flush() + close() + } + + val truncateStream = new FileOutputStream(file, true) + try { + truncateStream.getChannel.truncate(initialPosition) + } finally { + truncateStream.close() + } + } catch { + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) } } @@ -188,6 +196,7 @@ private[spark] class DiskBlockObjectWriter( // Only valid if called after commit() override def bytesWritten: Long = { - lastValidPosition - initialPosition + assert(finalPosition != -1, "bytesWritten is only valid after successful commit()") + finalPosition - initialPosition } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2e7ed7538e6e5..4d66ccea211fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -21,10 +21,11 @@ import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} -import org.apache.spark.Logging +import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.network.netty.{PathResolver, ShuffleSender} import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.SortShuffleManager /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -34,11 +35,13 @@ import org.apache.spark.util.Utils * * @param rootDirs The directories to use for storing block files. Data will be hashed among these. */ -private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String) +private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String) extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64) + + private val subDirsPerLocalDir = + shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid @@ -54,13 +57,19 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD addShutdownHook() /** - * Returns the physical file segment in which the given BlockId is located. - * If the BlockId has been mapped to a specific FileSegment, that will be returned. - * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly. + * Returns the physical file segment in which the given BlockId is located. If the BlockId has + * been mapped to a specific FileSegment by the shuffle layer, that will be returned. + * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId. 
*/ def getBlockLocation(blockId: BlockId): FileSegment = { - if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) { - shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) + val env = SparkEnv.get // NOTE: can be null in unit tests + if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) { + // For sort-based shuffle, let it figure out its blocks + val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager] + sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this) + } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) { + // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this + shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) } else { val file = getFile(blockId.name) new FileSegment(file, 0, file.length()) @@ -99,13 +108,18 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD getBlockLocation(blockId).file.exists() } - /** List all the blocks currently stored on disk by the disk manager. */ - def getAllBlocks(): Seq[BlockId] = { + /** List all the files currently stored on disk by the disk manager. */ + def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories subDirs.flatten.filter(_ != null).flatMap { dir => - val files = dir.list() + val files = dir.listFiles() if (files != null) files else Seq.empty - }.map(BlockId.apply) + } + } + + /** List all the blocks currently stored on disk by the disk manager. */ + def getAllBlocks(): Seq[BlockId] = { + getAllFiles().map(f => BlockId(f.getName)) } /** Produces a unique block id and File suitable for intermediate results. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 35910e552fe86..28aa35bc7e147 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.shuffle.sort.SortShuffleManager /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -58,6 +59,7 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroups associated with the block's reducer. */ +// TODO: Factor this into a separate class for each ShuffleManager implementation private[spark] class ShuffleBlockManager(blockManager: BlockManager) extends Logging { def conf = blockManager.conf @@ -67,6 +69,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) + // Are we using sort-based shuffle? 
+ val sortBasedShuffle = + conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName + private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 /** @@ -91,6 +97,20 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) + /** + * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle + * because it just writes a single file by itself. + */ + def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = { + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + val shuffleState = shuffleStates(shuffleId) + shuffleState.completedMapTasks.add(mapId) + } + + /** + * Get a ShuffleWriterGroup for the given map task, which will register it as complete + * when the writers are closed successfully + */ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) @@ -124,7 +144,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { if (consolidateShuffleFiles) { if (success) { val offsets = writers.map(_.fileSegment().offset) - fileGroup.recordMapOutput(mapId, offsets) + val lengths = writers.map(_.fileSegment().length) + fileGroup.recordMapOutput(mapId, offsets, lengths) } recycleFileGroup(fileGroup) } else { @@ -182,7 +203,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { + if (sortBasedShuffle) { + // There's a single block ID for each map, plus an index file for it + for (mapId <- state.completedMapTasks) { + val blockId = new ShuffleBlockId(shuffleId, mapId, 0) + blockManager.diskBlockManager.getFile(blockId).delete() + blockManager.diskBlockManager.getFile(blockId.name + ".index").delete() + } + } else if (consolidateShuffleFiles) { for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { file.delete() } @@ -220,6 +248,8 @@ object ShuffleBlockManager { * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. */ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { + private var numBlocks: Int = 0 + /** * Stores the absolute index of each mapId in the files of this group. For instance, * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. @@ -227,23 +257,27 @@ object ShuffleBlockManager { private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() /** - * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file. - * This ordering allows us to compute block lengths by examining the following block offset. + * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by + * position in the file. * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every * reducer. 
*/ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { new PrimitiveVector[Long]() } - - def numBlocks = mapIdToIndex.size + private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } def apply(bucketId: Int) = files(bucketId) - def recordMapOutput(mapId: Int, offsets: Array[Long]) { + def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { + assert(offsets.length == lengths.length) mapIdToIndex(mapId) = numBlocks + numBlocks += 1 for (i <- 0 until offsets.length) { blockOffsetsByReducer(i) += offsets(i) + blockLengthsByReducer(i) += lengths(i) } } @@ -251,16 +285,11 @@ object ShuffleBlockManager { def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { val file = files(reducerId) val blockOffsets = blockOffsetsByReducer(reducerId) + val blockLengths = blockLengthsByReducer(reducerId) val index = mapIdToIndex.getOrElse(mapId, -1) if (index >= 0) { val offset = blockOffsets(index) - val length = - if (index + 1 < numBlocks) { - blockOffsets(index + 1) - offset - } else { - file.length() - offset - } - assert(length >= 0) + val length = blockLengths(index) Some(new FileSegment(file, offset, length)) } else { None diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index efb527b4f03e6..da2f5d3172fe2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -130,32 +130,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) - // create executor summary map if necessary - val executorSummaryMap = stageData.executorSummary - executorSummaryMap.getOrElseUpdate(key = info.executorId, op = new ExecutorSummary) - - executorSummaryMap.get(info.executorId).foreach { y => - // first update failed-task, succeed-task - taskEnd.reason match { - case Success => - y.succeededTasks += 1 - case _ => - y.failedTasks += 1 - } - - // update duration - y.taskTime += info.duration - - val metrics = taskEnd.taskMetrics - if (metrics != null) { - metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } - metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } - metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } - y.memoryBytesSpilled += metrics.memoryBytesSpilled - y.diskBytesSpilled += metrics.diskBytesSpilled - } + val execSummaryMap = stageData.executorSummary + val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) + + taskEnd.reason match { + case Success => + execSummary.succeededTasks += 1 + case _ => + execSummary.failedTasks += 1 } - + execSummary.taskTime += info.duration stageData.numActiveTasks -= 1 val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = @@ -171,28 +155,75 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { (Some(e.toErrorString), None) } + if (!metrics.isEmpty) { + val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) + updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics) + } - val taskRunTime = metrics.map(_.executorRunTime).getOrElse(0L) - stageData.executorRunTime += taskRunTime - val inputBytes = metrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L) - stageData.inputBytes += inputBytes - - val shuffleRead = 
metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L) - stageData.shuffleReadBytes += shuffleRead - - val shuffleWrite = - metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L) - stageData.shuffleWriteBytes += shuffleWrite - - val memoryBytesSpilled = metrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageData.memoryBytesSpilled += memoryBytesSpilled + val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) + taskData.taskInfo = info + taskData.taskMetrics = metrics + taskData.errorMessage = errorMessage + } + } - val diskBytesSpilled = metrics.map(_.diskBytesSpilled).getOrElse(0L) - stageData.diskBytesSpilled += diskBytesSpilled + /** + * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage + * aggregate metrics by calculating deltas between the currently recorded metrics and the new + * metrics. + */ + def updateAggregateMetrics( + stageData: StageUIData, + execId: String, + taskMetrics: TaskMetrics, + oldMetrics: Option[TaskMetrics]) { + val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) + + val shuffleWriteDelta = + (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L)) + stageData.shuffleWriteBytes += shuffleWriteDelta + execSummary.shuffleWrite += shuffleWriteDelta + + val shuffleReadDelta = + (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) + stageData.shuffleReadBytes += shuffleReadDelta + execSummary.shuffleRead += shuffleReadDelta + + val diskSpillDelta = + taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) + stageData.diskBytesSpilled += diskSpillDelta + execSummary.diskBytesSpilled += diskSpillDelta + + val memorySpillDelta = + taskMetrics.memoryBytesSpilled - oldMetrics.map(_.memoryBytesSpilled).getOrElse(0L) + stageData.memoryBytesSpilled += memorySpillDelta + execSummary.memoryBytesSpilled += memorySpillDelta + + val timeDelta = + taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) + stageData.executorRunTime += timeDelta + } - stageData.taskData(info.taskId) = new TaskUIData(info, metrics, errorMessage) + override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { + for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + val stageData = stageIdToData.getOrElseUpdate(sid, { + logWarning("Metrics update for task in unknown stage " + sid) + new StageUIData + }) + val taskData = stageData.taskData.get(taskId) + taskData.map { t => + if (!t.taskInfo.finished) { + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics, + t.taskMetrics) + + // Overwrite task metrics + t.taskMetrics = Some(taskMetrics) + } + } } - } // end of onTaskEnd + } override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index be11a11695b01..2f96f7909c199 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -55,8 +55,11 @@ private[jobs] object UIData { var executorSummary = new HashMap[String, ExecutorSummary] } + /** + * These are kept mutable and reused throughout a 
task's lifetime to avoid excessive reallocation. + */ case class TaskUIData( - taskInfo: TaskInfo, - taskMetrics: Option[TaskMetrics] = None, - errorMessage: Option[String] = None) + var taskInfo: TaskInfo, + var taskMetrics: Option[TaskMetrics] = None, + var errorMessage: Option[String] = None) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 9930c717492f2..feafd654e9e71 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -18,13 +18,16 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap +import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import akka.actor.{ActorSystem, ExtendedActorSystem} +import akka.actor.{Actor, ActorRef, ActorSystem, ExtendedActorSystem} +import akka.pattern.ask + import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. @@ -124,4 +127,63 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 + + /** Returns the configured number of times to retry connecting */ + def numRetries(conf: SparkConf): Int = { + conf.getInt("spark.akka.num.retries", 3) + } + + /** Returns the configured number of milliseconds to wait on each retry */ + def retryWaitMs(conf: SparkConf): Int = { + conf.getInt("spark.akka.retry.wait", 3000) + } + + /** + * Send a message to the given actor and get its result within a default timeout, or + * throw a SparkException if this fails. 
+ */ + def askWithReply[T]( + message: Any, + actor: ActorRef, + retryAttempts: Int, + retryInterval: Int, + timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + if (actor == null) { + throw new SparkException("Error sending message as driverActor is null " + + "[message = " + message + "]") + } + var attempts = 0 + var lastException: Exception = null + while (attempts < retryAttempts) { + attempts += 1 + try { + val future = actor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message in " + attempts + " attempts", e) + } + Thread.sleep(retryInterval) + } + + throw new SparkException( + "Error sending message [message = " + message + "]", lastException) + } + + def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 9dcdafdd6350e..2e8fbf5a91ee7 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -52,7 +52,7 @@ private[spark] class FileLogger( override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } - private val fileSystem = Utils.getHadoopFileSystem(new URI(logDir)) + private val fileSystem = Utils.getHadoopFileSystem(logDir) var fileIndex = 0 // Only used if compression is enabled diff --git a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala index d769b54fa2fae..f77488ef3d449 100644 --- a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import org.apache.commons.lang.SystemUtils +import org.apache.commons.lang3.SystemUtils import org.slf4j.Logger import sun.misc.{Signal, SignalHandler} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8cbb9050f393b..30073a82857d2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -38,7 +38,7 @@ import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} @@ -286,17 +286,23 @@ private[spark] object Utils extends Logging { out: OutputStream, closeStreams: Boolean = false) { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = 
in.read(buf) - if (n != -1) { - out.write(buf, 0, n) + try { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + } + } + } finally { + if (closeStreams) { + try { + in.close() + } finally { + out.close() + } } - } - if (closeStreams) { - in.close() - out.close() } } @@ -868,9 +874,12 @@ private[spark] object Utils extends Logging { val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) val stream = new FileInputStream(file) - stream.skip(effectiveStart) - stream.read(buff) - stream.close() + try { + stream.skip(effectiveStart) + stream.read(buff) + } finally { + stream.close() + } Source.fromBytes(buff).mkString } @@ -1313,4 +1322,13 @@ private[spark] object Utils extends Logging { s"$className: $desc\n$st" } + /** + * Convert all spark properties set in the given SparkConf to a sequence of java options. + */ + def sparkJavaOpts(conf: SparkConf, filterKey: (String => Boolean) = _ => true): Seq[String] = { + conf.getAll + .filter { case (k, _) => filterKey(k) } + .map { case (k, v) => s"-D$k=$v" } + } + } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 6f263c39d1435..cb67a1c039f20 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -79,12 +79,16 @@ class ExternalAppendOnlyMap[K, V, C]( (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } - // Number of pairs in the in-memory map - private var numPairsInMemory = 0L + // Number of pairs inserted since last spill; note that we count them even if a value is merged + // with a previous key in case we're doing something like groupBy where the result grows + private var elementsRead = 0L // Number of in-memory pairs inserted before tracking the map's shuffle memory usage private val trackMemoryThreshold = 1000 + // How much of the shared memory pool this collection has claimed + private var myMemoryThreshold = 0L + /** * Size of object batches when reading/writing from serializers. * @@ -106,7 +110,6 @@ class ExternalAppendOnlyMap[K, V, C]( private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() - private val threadId = Thread.currentThread().getId /** * Insert the given key and value into the map. 
@@ -134,31 +137,35 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) { - val mapSize = currentMap.estimateSize() + if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && + currentMap.estimateSize() >= myMemoryThreshold) + { + val currentSize = currentMap.estimateSize() var shouldSpill = false val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap // Atomically check whether there is sufficient memory in the global pool for // this map to grow and, if possible, allocate the required amount shuffleMemoryMap.synchronized { + val threadId = Thread.currentThread().getId val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) val availableMemory = maxMemoryThreshold - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) - // Assume map growth factor is 2x - shouldSpill = availableMemory < mapSize * 2 + // Try to allocate at least 2x more memory, otherwise spill + shouldSpill = availableMemory < currentSize * 2 if (!shouldSpill) { - shuffleMemoryMap(threadId) = mapSize * 2 + shuffleMemoryMap(threadId) = currentSize * 2 + myMemoryThreshold = currentSize * 2 } } // Do not synchronize spills if (shouldSpill) { - spill(mapSize) + spill(currentSize) } } currentMap.changeValue(curEntry._1, update) - numPairsInMemory += 1 + elementsRead += 1 } } @@ -178,9 +185,10 @@ class ExternalAppendOnlyMap[K, V, C]( /** * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. */ - private def spill(mapSize: Long) { + private def spill(mapSize: Long): Unit = { spillCount += 1 - logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" + val threadId = Thread.currentThread().getId + logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) @@ -191,7 +199,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables def flush() = { - writer.commit() + writer.commitAndClose() val bytesWritten = writer.bytesWritten batchSizes.append(bytesWritten) _diskBytesSpilled += bytesWritten @@ -227,7 +235,9 @@ class ExternalAppendOnlyMap[K, V, C]( shuffleMemoryMap.synchronized { shuffleMemoryMap(Thread.currentThread().getId) = 0 } - numPairsInMemory = 0 + myMemoryThreshold = 0 + + elementsRead = 0 _memoryBytesSpilled += mapSize } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala new file mode 100644 index 0000000000000..6e415a2bd8ce2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -0,0 +1,662 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
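
The new spill trigger above (and the near-identical check in ExternalSorter.maybeSpill further below, which the TODO notes should eventually be factored out) boils down to: every 32 elements past the tracking threshold, estimate the collection's size and try to claim twice that amount from a memory pool shared by all task threads, spilling if the claim cannot be granted. A simplified standalone sketch, with the pool size and map as stand-ins for maxMemoryThreshold and SparkEnv's shuffleMemoryMap:

```scala
import scala.collection.mutable

object SpillDecisionSketch {
  val maxPoolBytes = 512L * 1024 * 1024                  // stand-in for maxMemoryThreshold
  val claimedByThread = mutable.HashMap[Long, Long]()    // stand-in for the shared shuffleMemoryMap

  /** Returns true if the caller should spill, otherwise records the doubled memory claim. */
  def shouldSpill(threadId: Long, estimatedSize: Long): Boolean = claimedByThread.synchronized {
    val claimedByOthers = claimedByThread.values.sum - claimedByThread.getOrElse(threadId, 0L)
    val available = maxPoolBytes - claimedByOthers
    if (available < estimatedSize * 2) {
      true                                               // not enough room to double: caller spills
    } else {
      claimedByThread(threadId) = estimatedSize * 2      // claim the doubled threshold and keep going
      false
    }
  }
}
```
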
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io._ +import java.util.Comparator + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable + +import com.google.common.io.ByteStreams + +import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.BlockId + +/** + * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner + * pairs of type (K, C). Uses a Partitioner to first group the keys into partitions, and then + * optionally sorts keys within each partition using a custom Comparator. Can output a single + * partitioned file with a different byte range for each partition, suitable for shuffle fetches. + * + * If combining is disabled, the type C must equal V -- we'll cast the objects at the end. + * + * @param aggregator optional Aggregator with combine functions to use for merging data + * @param partitioner optional Partitioner; if given, sort by partition ID and then key + * @param ordering optional Ordering to sort keys within each partition; should be a total ordering + * @param serializer serializer to use when spilling to disk + * + * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really + * want the output keys to be sorted. In a map task without map-side combine for example, you + * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do + * want to do combining, having an Ordering is more efficient than not having it. + * + * At a high level, this class works as follows: + * + * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if + * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers, + * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to + * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner). + * + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. + * + * - When the user requests an iterator, the spilled files are merged, along with any remaining + * in-memory data, using the same sort order defined above (unless both sorting and aggregation + * are disabled). If we need to aggregate by key, we either use a total ordering from the + * ordering parameter, or read the keys with the same hash code and compare them with each other + * for equality to merge values. + * + * - Users are expected to call stop() at the end to delete all the intermediate files. 
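
A hedged usage sketch of the sorter described above. The class is private[spark] and needs a live SparkEnv for its block and disk managers, so this only illustrates the intended call pattern from inside a task rather than a runnable standalone program; the summing aggregator and partition count are made up for illustration.

```scala
import org.apache.spark.{Aggregator, HashPartitioner}
import org.apache.spark.util.collection.ExternalSorter

object ExternalSorterSketch {
  // Sums values per key, partitioned for 4 reducers, with no key ordering requested.
  def writeSortedOutput(records: Iterator[(String, Int)]): Unit = {
    val agg = new Aggregator[String, Int, Int](
      createCombiner = v => v,
      mergeValue = (sum, v) => sum + v,
      mergeCombiners = (s1, s2) => s1 + s2)
    val sorter = new ExternalSorter[String, Int, Int](
      aggregator = Some(agg),
      partitioner = Some(new HashPartitioner(4)),
      ordering = None,          // keys only need grouping by partition, not sorting
      serializer = None)
    sorter.write(records)
    // Partitions come back in partition ID order and must be consumed in that order.
    for ((partitionId, elements) <- sorter.partitionedIterator) {
      elements.foreach(kc => println(s"partition $partitionId: ${kc._1} -> ${kc._2}"))
    }
    sorter.stop()               // deletes the intermediate spill files
  }
}
```
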
+ */ +private[spark] class ExternalSorter[K, V, C]( + aggregator: Option[Aggregator[K, V, C]] = None, + partitioner: Option[Partitioner] = None, + ordering: Option[Ordering[K]] = None, + serializer: Option[Serializer] = None) extends Logging { + + private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) + private val shouldPartition = numPartitions > 1 + + private val blockManager = SparkEnv.get.blockManager + private val diskBlockManager = blockManager.diskBlockManager + private val ser = Serializer.getSerializer(serializer) + private val serInstance = ser.newInstance() + + private val conf = SparkEnv.get.conf + private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + + // Size of object batches when reading/writing from serializers. + // + // Objects are written in batches, with each batch using its own serialization stream. This + // cuts down on the size of reference-tracking maps constructed when deserializing a stream. + // + // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers + // grow internal data structures by growing + copying every time the number of objects doubles. + private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) + + private def getPartition(key: K): Int = { + if (shouldPartition) partitioner.get.getPartition(key) else 0 + } + + // Data structures to store in-memory objects before we spill. Depending on whether we have an + // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we + // store them in an array buffer. + private var map = new SizeTrackingAppendOnlyMap[(Int, K), C] + private var buffer = new SizeTrackingPairBuffer[(Int, K), C] + + // Number of pairs read from input since last spill; note that we count them even if a value is + // merged with a previous key in case we're doing something like groupBy where the result grows + private var elementsRead = 0L + + // What threshold of elementsRead we start estimating map size at. + private val trackMemoryThreshold = 1000 + + // Spilling statistics + private var spillCount = 0 + private var _memoryBytesSpilled = 0L + private var _diskBytesSpilled = 0L + + // Collective memory threshold shared across all running tasks + private val maxMemoryThreshold = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } + + // How much of the shared memory pool this collection has claimed + private var myMemoryThreshold = 0L + + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. + // Can be a partial ordering by hash code if a total ordering is not provided through by the + // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some + // non-equal keys also have this, so we need to do a later pass to find truly equal keys). + // Note that we ignore this if no aggregator and no ordering are given. 
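
A tiny illustration (not from the patch) of why the default comparator above is only a partial ordering: distinct keys can share a hash code, so keys that compare as equal must still be checked for real equality during merging.

```scala
object HashOrderingSketch {
  // Same idea as the default keyComparator above: order keys by hash code only.
  def hashCompare(a: AnyRef, b: AnyRef): Int = {
    val h1 = if (a == null) 0 else a.hashCode()
    val h2 = if (b == null) 0 else b.hashCode()
    Integer.compare(h1, h2)
  }

  def main(args: Array[String]): Unit = {
    // "Aa" and "BB" are a classic JVM hash collision: same hashCode, different strings.
    println("Aa".hashCode == "BB".hashCode)   // true
    println(hashCompare("Aa", "BB"))          // 0: the comparator cannot tell them apart
    println("Aa" == "BB")                     // false: equal-comparing keys still need an equality pass
  }
}
```
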
+ private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] { + override def compare(a: K, b: K): Int = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + h1 - h2 + } + }) + + // A comparator for (Int, K) elements that orders them by partition and then possibly by key + private val partitionKeyComparator: Comparator[(Int, K)] = { + if (ordering.isDefined || aggregator.isDefined) { + // Sort by partition ID then key comparator + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + val partitionDiff = a._1 - b._1 + if (partitionDiff != 0) { + partitionDiff + } else { + keyComparator.compare(a._2, b._2) + } + } + } + } else { + // Just sort it by partition ID + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + } + } + + // Information about a spilled file. Includes sizes in bytes of "batches" written by the + // serializer as we periodically reset its stream, as well as number of elements in each + // partition, used to efficiently keep track of partitions when merging. + private[this] case class SpilledFile( + file: File, + blockId: BlockId, + serializerBatchSizes: Array[Long], + elementsPerPartition: Array[Long]) + private val spills = new ArrayBuffer[SpilledFile] + + def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + // TODO: stop combining if we find that the reduction factor isn't high + val shouldCombine = aggregator.isDefined + + if (shouldCombine) { + // Combine values in-memory first using our AppendOnlyMap + val mergeValue = aggregator.get.mergeValue + val createCombiner = aggregator.get.createCombiner + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (records.hasNext) { + elementsRead += 1 + kv = records.next() + map.changeValue((getPartition(kv._1), kv._1), update) + maybeSpill(usingMap = true) + } + } else { + // Stick values into our buffer + while (records.hasNext) { + elementsRead += 1 + val kv = records.next() + buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + maybeSpill(usingMap = false) + } + } + } + + /** + * Spill the current in-memory collection to disk if needed. + * + * @param usingMap whether we're using a map or buffer as our current in-memory collection + */ + private def maybeSpill(usingMap: Boolean): Unit = { + if (!spillingEnabled) { + return + } + + val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + + // TODO: factor this out of both here and ExternalAppendOnlyMap + if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && + collection.estimateSize() >= myMemoryThreshold) + { + // TODO: This logic doesn't work if there are two external collections being used in the same + // task (e.g. 
to read shuffle output and write it out into another shuffle) [SPARK-2711] + + val currentSize = collection.estimateSize() + var shouldSpill = false + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + + // Atomically check whether there is sufficient memory in the global pool for + // us to double our threshold + shuffleMemoryMap.synchronized { + val threadId = Thread.currentThread().getId + val previouslyClaimedMemory = shuffleMemoryMap.get(threadId) + val availableMemory = maxMemoryThreshold - + (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L)) + + // Try to allocate at least 2x more memory, otherwise spill + shouldSpill = availableMemory < currentSize * 2 + if (!shouldSpill) { + shuffleMemoryMap(threadId) = currentSize * 2 + myMemoryThreshold = currentSize * 2 + } + } + // Do not hold lock during spills + if (shouldSpill) { + spill(currentSize, usingMap) + } + } + } + + /** + * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. + * + * @param usingMap whether we're using a map or buffer as our current in-memory collection + */ + private def spill(memorySize: Long, usingMap: Boolean): Unit = { + val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + val memorySize = collection.estimateSize() + + spillCount += 1 + val threadId = Thread.currentThread().getId + logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" + .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + val (blockId, file) = diskBlockManager.createTempBlock() + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + var objectsWritten = 0 // Objects written since the last flush + + // List of batch sizes (bytes) in the order they are written to disk + val batchSizes = new ArrayBuffer[Long] + + // How many elements we have in each partition + val elementsPerPartition = new Array[Long](numPartitions) + + // Flush the disk writer's contents to disk, and update relevant variables. + // The writer is closed at the end of this process, and cannot be reused. + def flush() = { + writer.commitAndClose() + val bytesWritten = writer.bytesWritten + batchSizes.append(bytesWritten) + _diskBytesSpilled += bytesWritten + objectsWritten = 0 + } + + try { + val it = collection.destructiveSortedIterator(partitionKeyComparator) + while (it.hasNext) { + val elem = it.next() + val partitionId = elem._1._1 + val key = elem._1._2 + val value = elem._2 + writer.write(key) + writer.write(value) + elementsPerPartition(partitionId) += 1 + objectsWritten += 1 + + if (objectsWritten == serializerBatchSize) { + flush() + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + } + } + if (objectsWritten > 0) { + flush() + } + writer.close() + } catch { + case e: Exception => + writer.close() + file.delete() + throw e + } + + if (usingMap) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } else { + buffer = new SizeTrackingPairBuffer[(Int, K), C] + } + + // Reset the amount of shuffle memory used by this map in the global pool + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + shuffleMemoryMap.synchronized { + shuffleMemoryMap(Thread.currentThread().getId) = 0 + } + myMemoryThreshold = 0 + + spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) + _memoryBytesSpilled += memorySize + } + + /** + * Merge a sequence of sorted files, giving an iterator over partitions and then over elements + * inside each partition. 
This can be used to either write out a new file or return data to + * the user. + * + * Returns an iterator over all the data written to this object, grouped by partition. For each + * partition we then have an iterator over its contents, and these are expected to be accessed + * in order (you can't "skip ahead" to one partition without reading the previous one). + * Guaranteed to return a key-value pair for each partition, in order of partition ID. + */ + private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)]) + : Iterator[(Int, Iterator[Product2[K, C]])] = { + val readers = spills.map(new SpillReader(_)) + val inMemBuffered = inMemory.buffered + (0 until numPartitions).iterator.map { p => + val inMemIterator = new IteratorForPartition(p, inMemBuffered) + val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator) + if (aggregator.isDefined) { + // Perform partial aggregation across partitions + (p, mergeWithAggregation( + iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined)) + } else if (ordering.isDefined) { + // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey); + // sort the elements without trying to merge them + (p, mergeSort(iterators, ordering.get)) + } else { + (p, iterators.iterator.flatten) + } + } + } + + /** + * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys. + */ + private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K]) + : Iterator[Product2[K, C]] = + { + val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) + type Iter = BufferedIterator[Product2[K, C]] + val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { + // Use the reverse of comparator.compare because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + }) + heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true + new Iterator[Product2[K, C]] { + override def hasNext: Boolean = !heap.isEmpty + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val firstBuf = heap.dequeue() + val firstPair = firstBuf.next() + if (firstBuf.hasNext) { + heap.enqueue(firstBuf) + } + firstPair + } + } + } + + /** + * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each + * iterator is sorted by key with a given comparator. If the comparator is not a total ordering + * (e.g. when we sort objects by hash code and different keys may compare as equal although + * they're not), we still merge them by doing equality tests for all keys that compare as equal. + */ + private def mergeWithAggregation( + iterators: Seq[Iterator[Product2[K, C]]], + mergeCombiners: (C, C) => C, + comparator: Comparator[K], + totalOrder: Boolean) + : Iterator[Product2[K, C]] = + { + if (!totalOrder) { + // We only have a partial ordering, e.g. comparing the keys by hash code, which means that + // multiple distinct keys might be treated as equal by the ordering. To deal with this, we + // need to read all keys considered equal by the ordering at once and compare them. 
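
The mergeSort method above is a k-way merge: each sorted input is buffered, and a priority queue keyed on the head elements repeatedly yields the globally smallest item. A minimal standalone version of the same idea:

```scala
import scala.collection.mutable

object MergeSortSketch {
  def mergeSorted[T](inputs: Seq[Iterator[T]])(implicit ord: Ordering[T]): Iterator[T] = {
    type Iter = BufferedIterator[T]
    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
      // Negate because PriorityQueue dequeues the maximum and we want the smallest head first.
      override def compare(x: Iter, y: Iter): Int = -ord.compare(x.head, y.head)
    })
    heap.enqueue(inputs.filter(_.hasNext).map(_.buffered): _*)
    new Iterator[T] {
      override def hasNext: Boolean = heap.nonEmpty
      override def next(): T = {
        val smallest = heap.dequeue()                  // iterator whose head is currently the minimum
        val elem = smallest.next()
        if (smallest.hasNext) heap.enqueue(smallest)   // re-insert with its new head
        elem
      }
    }
  }

  def main(args: Array[String]): Unit = {
    println(mergeSorted(Seq(Iterator(1, 4, 7), Iterator(2, 3, 9), Iterator(5))).toList)
    // List(1, 2, 3, 4, 5, 7, 9)
  }
}
```
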
+ new Iterator[Iterator[Product2[K, C]]] { + val sorted = mergeSort(iterators, comparator).buffered + + // Buffers reused across elements to decrease memory allocation + val keys = new ArrayBuffer[K] + val combiners = new ArrayBuffer[C] + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Iterator[Product2[K, C]] = { + if (!hasNext) { + throw new NoSuchElementException + } + keys.clear() + combiners.clear() + val firstPair = sorted.next() + keys += firstPair._1 + combiners += firstPair._2 + val key = firstPair._1 + while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) { + val pair = sorted.next() + var i = 0 + var foundKey = false + while (i < keys.size && !foundKey) { + if (keys(i) == pair._1) { + combiners(i) = mergeCombiners(combiners(i), pair._2) + foundKey = true + } + i += 1 + } + if (!foundKey) { + keys += pair._1 + combiners += pair._2 + } + } + + // Note that we return an iterator of elements since we could've had many keys marked + // equal by the partial order; we flatten this below to get a flat iterator of (K, C). + keys.iterator.zip(combiners.iterator) + } + }.flatMap(i => i) + } else { + // We have a total ordering, so the objects with the same key are sequential. + new Iterator[Product2[K, C]] { + val sorted = mergeSort(iterators, comparator).buffered + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val elem = sorted.next() + val k = elem._1 + var c = elem._2 + while (sorted.hasNext && sorted.head._1 == k) { + c = mergeCombiners(c, sorted.head._2) + } + (k, c) + } + } + } + } + + /** + * An internal class for reading a spilled file partition by partition. Expects all the + * partitions to be requested in order. + */ + private[this] class SpillReader(spill: SpilledFile) { + val fileStream = new FileInputStream(spill.file) + val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) + + // Track which partition and which batch stream we're in. These will be the indices of + // the next element we will read. We'll also store the last partition read so that + // readNextPartition() can figure out what partition that was from. + var partitionId = 0 + var indexInPartition = 0L + var batchStreamsRead = 0 + var indexInBatch = 0 + var lastPartitionId = 0 + + skipToNextPartition() + + // An intermediate stream that reads from exactly one batch + // This guards against pre-fetching and other arbitrary behavior of higher level streams + var batchStream = nextBatchStream() + var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) + var deserStream = serInstance.deserializeStream(compressedStream) + var nextItem: (K, C) = null + var finished = false + + /** Construct a stream that only reads from the next batch */ + def nextBatchStream(): InputStream = { + if (batchStreamsRead < spill.serializerBatchSizes.length) { + batchStreamsRead += 1 + ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1)) + } else { + // No more batches left; give an empty stream + bufferedStream + } + } + + /** + * Update partitionId if we have reached the end of our current partition, possibly skipping + * empty partitions on the way. 
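
A small sketch of the batching trick used by SpillReader above, with raw bytes standing in for serialized objects and Guava assumed on the classpath as it is in Spark: batch sizes recorded at write time (SpilledFile.serializerBatchSizes) let the reader wrap one shared file stream in ByteStreams.limit, so reading a batch can never run past its boundary into the next one.

```scala
import java.io.{BufferedInputStream, File, FileInputStream, FileOutputStream}
import com.google.common.io.ByteStreams

object BatchLimitSketch {
  def main(args: Array[String]): Unit = {
    val file = File.createTempFile("spill-batches", ".bin")
    val batches = Seq("first batch".getBytes("UTF-8"), "second, longer batch".getBytes("UTF-8"))

    // Write the batches back to back, recording each one's size in bytes.
    val out = new FileOutputStream(file)
    try batches.foreach(b => out.write(b)) finally out.close()
    val batchSizes = batches.map(_.length.toLong)

    // Read them back through per-batch limited views of one shared buffered stream.
    val in = new BufferedInputStream(new FileInputStream(file))
    try {
      for (size <- batchSizes) {
        val batchOnly = ByteStreams.limit(in, size)          // reports EOF exactly at the batch boundary
        println(new String(ByteStreams.toByteArray(batchOnly), "UTF-8"))
      }
    } finally {
      in.close()
    }
  }
}
```
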
+ */ + private def skipToNextPartition() { + while (partitionId < numPartitions && + indexInPartition == spill.elementsPerPartition(partitionId)) { + partitionId += 1 + indexInPartition = 0L + } + } + + /** + * Return the next (K, C) pair from the deserialization stream and update partitionId, + * indexInPartition, indexInBatch and such to match its location. + * + * If the current batch is drained, construct a stream for the next batch and read from it. + * If no more pairs are left, return null. + */ + private def readNextItem(): (K, C) = { + if (finished) { + return null + } + val k = deserStream.readObject().asInstanceOf[K] + val c = deserStream.readObject().asInstanceOf[C] + lastPartitionId = partitionId + // Start reading the next batch if we're done with this one + indexInBatch += 1 + if (indexInBatch == serializerBatchSize) { + batchStream = nextBatchStream() + compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) + deserStream = serInstance.deserializeStream(compressedStream) + indexInBatch = 0 + } + // Update the partition location of the element we're reading + indexInPartition += 1 + skipToNextPartition() + // If we've finished reading the last partition, remember that we're done + if (partitionId == numPartitions) { + finished = true + deserStream.close() + } + (k, c) + } + + var nextPartitionToRead = 0 + + def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] { + val myPartition = nextPartitionToRead + nextPartitionToRead += 1 + + override def hasNext: Boolean = { + if (nextItem == null) { + nextItem = readNextItem() + if (nextItem == null) { + return false + } + } + assert(lastPartitionId >= myPartition) + // Check that we're still in the right partition; note that readNextItem will have returned + // null at EOF above so we would've returned false there + lastPartitionId == myPartition + } + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val item = nextItem + nextItem = null + item + } + } + } + + /** + * Return an iterator over all the data written to this object, grouped by partition and + * aggregated by the requested aggregator. For each partition we then have an iterator over its + * contents, and these are expected to be accessed in order (you can't "skip ahead" to one + * partition without reading the previous one). Guaranteed to return a key-value pair for each + * partition, in order of partition ID. + * + * For now, we just merge all the spilled files in once pass, but this can be modified to + * support hierarchical merging. 
+ */ + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + val usingMap = aggregator.isDefined + val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + if (spills.isEmpty) { + // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps + // we don't even need to sort by anything other than partition ID + if (!ordering.isDefined) { + // The user isn't requested sorted keys, so only sort by partition ID, not key + val partitionComparator = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + } else { + // We do need to sort by both partition ID and key + groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator)) + } + } else { + // General case: merge spilled and in-memory data + merge(spills, collection.destructiveSortedIterator(partitionKeyComparator)) + } + } + + /** + * Return an iterator over all the data written to this object, aggregated by our aggregator. + */ + def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) + + def stop(): Unit = { + spills.foreach(s => s.file.delete()) + spills.clear() + } + + def memoryBytesSpilled: Long = _memoryBytesSpilled + + def diskBytesSpilled: Long = _diskBytesSpilled + + /** + * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, + * group together the pairs for each partition into a sub-iterator. + * + * @param data an iterator of elements, assumed to already be sorted by partition ID + */ + private def groupByPartition(data: Iterator[((Int, K), C)]) + : Iterator[(Int, Iterator[Product2[K, C]])] = + { + val buffered = data.buffered + (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) + } + + /** + * An iterator that reads only the elements for a given partition ID from an underlying buffered + * stream, assuming this partition is the next one to be read. Used to make it easier to return + * partitioned iterators from our in-memory collection. + */ + private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)]) + extends Iterator[Product2[K, C]] + { + override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val elem = data.next() + (elem._1._2, elem._2) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala index de61e1d17fe10..eb4de413867a0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -20,8 +20,9 @@ package org.apache.spark.util.collection /** * An append-only map that keeps track of its estimated size in bytes. 
*/ -private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] with SizeTracker { - +private[spark] class SizeTrackingAppendOnlyMap[K, V] + extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V] +{ override def update(key: K, value: V): Unit = { super.update(key, value) super.afterUpdate() diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala new file mode 100644 index 0000000000000..9e9c16c5a2962 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +/** + * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes. + */ +private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64) + extends SizeTracker with SizeTrackingPairCollection[K, V] +{ + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + + // Basic growable array data structure. We use a single array of AnyRef to hold both the keys + // and the values, so that we can sort them efficiently with KVArraySortDataFormat. 
+ private var capacity = initialCapacity + private var curSize = 0 + private var data = new Array[AnyRef](2 * initialCapacity) + + /** Add an element into the buffer */ + def insert(key: K, value: V): Unit = { + if (curSize == capacity) { + growArray() + } + data(2 * curSize) = key.asInstanceOf[AnyRef] + data(2 * curSize + 1) = value.asInstanceOf[AnyRef] + curSize += 1 + afterUpdate() + } + + /** Total number of elements in buffer */ + override def size: Int = curSize + + /** Iterate over the elements of the buffer */ + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { + var pos = 0 + + override def hasNext: Boolean = pos < curSize + + override def next(): (K, V) = { + if (!hasNext) { + throw new NoSuchElementException + } + val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + pos += 1 + pair + } + } + + /** Double the size of the array because we've reached capacity */ + private def growArray(): Unit = { + if (capacity == (1 << 29)) { + // Doubling the capacity would create an array bigger than Int.MaxValue, so don't + throw new Exception("Can't grow buffer beyond 2^29 elements") + } + val newCapacity = capacity * 2 + val newArray = new Array[AnyRef](2 * newCapacity) + System.arraycopy(data, 0, newArray, 0, 2 * capacity) + data = newArray + capacity = newCapacity + resetSamples() + } + + /** Iterate through the data in a given order. For this class this is not really destructive. */ + override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = { + new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator) + iterator + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala new file mode 100644 index 0000000000000..faa4e2b12ddb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +/** + * A common interface for our size-tracking collections of key-value pairs, which are used in + * external operations. These all support estimating the size and obtaining a memory-efficient + * sorted iterator. + */ +// TODO: should extend Iterable[Product2[K, V]] instead of (K, V) +private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] { + /** Estimate the collection's current memory usage in bytes. */ + def estimateSize(): Long + + /** Iterate through the data in a given key order. This may destroy the underlying collection. 
*/ + def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] +} diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index d10141b90e621..c9a864ae62778 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -81,6 +81,9 @@ private[spark] object SamplingUtils { * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success * rate, where success rate is defined the same as in sampling with replacement. * + * The smallest sampling rate supported is 1e-10 (in order to avoid running into the limit of the + * RNG's resolution). + * * @param sampleSizeLowerBound sample size * @param total size of RDD * @param withReplacement whether sampling with replacement @@ -88,14 +91,73 @@ private[spark] object SamplingUtils { */ def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = { - val fraction = sampleSizeLowerBound.toDouble / total if (withReplacement) { - val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 - fraction + numStDev * math.sqrt(fraction / total) + PoissonBounds.getUpperBound(sampleSizeLowerBound) / total } else { - val delta = 1e-4 - val gamma = - math.log(delta) / total - math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) + val fraction = sampleSizeLowerBound.toDouble / total + BinomialBounds.getUpperBound(1e-4, total, fraction) } } } + +/** + * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact + * sample sizes with high confidence when sampling with replacement. + */ +private[spark] object PoissonBounds { + + /** + * Returns a lambda such that Pr[X > s] is very small, where X ~ Pois(lambda). + */ + def getLowerBound(s: Double): Double = { + math.max(s - numStd(s) * math.sqrt(s), 1e-15) + } + + /** + * Returns a lambda such that Pr[X < s] is very small, where X ~ Pois(lambda). + * + * @param s sample size + */ + def getUpperBound(s: Double): Double = { + math.max(s + numStd(s) * math.sqrt(s), 1e-10) + } + + private def numStd(s: Double): Double = { + // TODO: Make it tighter. + if (s < 6.0) { + 12.0 + } else if (s < 16.0) { + 9.0 + } else { + 6.0 + } + } +} + +/** + * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact + * sample size with high confidence when sampling without replacement. + */ +private[spark] object BinomialBounds { + + val minSamplingRate = 1e-10 + + /** + * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`, + * it is very unlikely to have more than `fraction * n` successes. + */ + def getLowerBound(delta: Double, n: Long, fraction: Double): Double = { + val gamma = - math.log(delta) / n * (2.0 / 3.0) + fraction + gamma - math.sqrt(gamma * gamma + 3 * gamma * fraction) + } + + /** + * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`, + * it is very unlikely to have less than `fraction * n` successes. 
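
A worked example of this without-replacement bound (its implementation follows immediately below): to draw at least 1000 items out of a million with failure probability delta = 1e-4, the naive rate of 0.001 is inflated by roughly 15 percent.

```scala
object SampleFractionExample {
  // Same formula as BinomialBounds.getUpperBound, including the 1e-10 minimum sampling rate.
  def upperBound(delta: Double, n: Long, fraction: Double): Double = {
    val gamma = -math.log(delta) / n
    math.min(1,
      math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
  }

  def main(args: Array[String]): Unit = {
    val total = 1000000L
    val naive = 1000.0 / total                  // 0.001 merely *expects* 1000 sampled items
    println(upperBound(1e-4, total, naive))     // ~0.00115: oversample ~15% to get >= 1000 w.h.p.
  }
}
```
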
+ */ + def getUpperBound(delta: Double, n: Long, fraction: Double): Double = { + val gamma = - math.log(delta) / n + math.min(1, + math.max(minSamplingRate, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala new file mode 100644 index 0000000000000..8f95d7c6b799b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +import scala.collection.Map +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD + +/** + * Auxiliary functions and data structures for the sampleByKey method in PairRDDFunctions. + * + * Essentially, when exact sample size is necessary, we make additional passes over the RDD to + * compute the exact threshold value to use for each stratum to guarantee exact sample size with + * high probability. This is achieved by maintaining a waitlist of size O(log(s)), where s is the + * desired sample size for each stratum. + * + * Like in simple random sampling, we generate a random value for each item from the + * uniform distribution [0.0, 1.0]. All items with values <= min(values of items in the waitlist) + * are accepted into the sample instantly. The threshold for instant accept is designed so that + * s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by maintaining a + * waitlist size = O(sqrt(s)), we will be able to create a sample of the exact size s by adding + * a portion of the waitlist to the set of items that are instantly accepted. The exact threshold + * is computed by sorting the values in the waitlist and picking the value at (s - numAccepted). + * + * Note that since we use the same seed for the RNG when computing the thresholds and the actual + * sample, our computed thresholds are guaranteed to produce the desired sample size. + * + * For more theoretical background on the sampling techniques used here, please refer to + * http://jmlr.org/proceedings/papers/v28/meng13a.html + */ + +private[spark] object StratifiedSamplingUtils extends Logging { + + /** + * Count the number of items instantly accepted and generate the waitlist for each stratum. + * + * This is only invoked when exact sample size is required. 
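
A toy, single-stratum sketch of the accept/waitlist scheme described above. The two bounds are hand-picked stand-ins for the PoissonBounds/BinomialBounds values, scala.util.Random stands in for the RNG, and the fallback branches loosely mirror the computeThresholdByKey logic defined later.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object WaitlistSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(42L)
    val numItems = 10000
    val sampleSize = 100
    val acceptBound = 0.007      // draws below this are accepted instantly
    val waitListBound = 0.013    // draws below this (but above acceptBound) join the waitlist

    var numAccepted = 0
    val waitList = new ArrayBuffer[Double]
    for (_ <- 1 to numItems) {
      val x = rng.nextDouble()
      if (x < acceptBound) numAccepted += 1
      else if (x < waitListBound) waitList += x
    }

    // Exact threshold: admit just enough of the smallest waitlisted values to reach sampleSize.
    val needFromWaitList = sampleSize - numAccepted
    val threshold =
      if (needFromWaitList <= 0) acceptBound                     // pre-accepted too many
      else if (needFromWaitList >= waitList.size) waitListBound  // waitlist too short
      else waitList.sorted.apply(needFromWaitList)
    println(s"instantly accepted = $numAccepted, waitlisted = ${waitList.size}, threshold = $threshold")
    // Re-running the same seeded RNG and accepting draws below `threshold` yields exactly sampleSize items.
  }
}
```
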
+ */ + def getAcceptanceResults[K, V](rdd: RDD[(K, V)], + withReplacement: Boolean, + fractions: Map[K, Double], + counts: Option[Map[K, Long]], + seed: Long): mutable.Map[K, AcceptanceResult] = { + val combOp = getCombOp[K] + val mappedPartitionRDD = rdd.mapPartitionsWithIndex { case (partition, iter) => + val zeroU: mutable.Map[K, AcceptanceResult] = new mutable.HashMap[K, AcceptanceResult]() + val rng = new RandomDataGenerator() + rng.reSeed(seed + partition) + val seqOp = getSeqOp(withReplacement, fractions, rng, counts) + Iterator(iter.aggregate(zeroU)(seqOp, combOp)) + } + mappedPartitionRDD.reduce(combOp) + } + + /** + * Returns the function used by aggregate to collect sampling statistics for each partition. + */ + def getSeqOp[K, V](withReplacement: Boolean, + fractions: Map[K, Double], + rng: RandomDataGenerator, + counts: Option[Map[K, Long]]): + (mutable.Map[K, AcceptanceResult], (K, V)) => mutable.Map[K, AcceptanceResult] = { + val delta = 5e-5 + (result: mutable.Map[K, AcceptanceResult], item: (K, V)) => { + val key = item._1 + val fraction = fractions(key) + if (!result.contains(key)) { + result += (key -> new AcceptanceResult()) + } + val acceptResult = result(key) + + if (withReplacement) { + // compute acceptBound and waitListBound only if they haven't been computed already + // since they don't change from iteration to iteration. + // TODO change this to the streaming version + if (acceptResult.areBoundsEmpty) { + val n = counts.get(key) + val sampleSize = math.ceil(n * fraction).toLong + val lmbd1 = PoissonBounds.getLowerBound(sampleSize) + val lmbd2 = PoissonBounds.getUpperBound(sampleSize) + acceptResult.acceptBound = lmbd1 / n + acceptResult.waitListBound = (lmbd2 - lmbd1) / n + } + val acceptBound = acceptResult.acceptBound + val copiesAccepted = if (acceptBound == 0.0) 0L else rng.nextPoisson(acceptBound) + if (copiesAccepted > 0) { + acceptResult.numAccepted += copiesAccepted + } + val copiesWaitlisted = rng.nextPoisson(acceptResult.waitListBound) + if (copiesWaitlisted > 0) { + acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rng.nextUniform()) + } + } else { + // We use the streaming version of the algorithm for sampling without replacement to avoid + // using an extra pass over the RDD for computing the count. + // Hence, acceptBound and waitListBound change on every iteration. + acceptResult.acceptBound = + BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction) + acceptResult.waitListBound = + BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction) + + val x = rng.nextUniform() + if (x < acceptResult.acceptBound) { + acceptResult.numAccepted += 1 + } else if (x < acceptResult.waitListBound) { + acceptResult.waitList += x + } + } + acceptResult.numItems += 1 + result + } + } + + /** + * Returns the function used combine results returned by seqOp from different partitions. 
+ */ + def getCombOp[K]: (mutable.Map[K, AcceptanceResult], mutable.Map[K, AcceptanceResult]) + => mutable.Map[K, AcceptanceResult] = { + (result1: mutable.Map[K, AcceptanceResult], result2: mutable.Map[K, AcceptanceResult]) => { + // take union of both key sets in case one partition doesn't contain all keys + result1.keySet.union(result2.keySet).foreach { key => + // Use result2 to keep the combined result since r1 is usual empty + val entry1 = result1.get(key) + if (result2.contains(key)) { + result2(key).merge(entry1) + } else { + if (entry1.isDefined) { + result2 += (key -> entry1.get) + } + } + } + result2 + } + } + + /** + * Given the result returned by getCounts, determine the threshold for accepting items to + * generate exact sample size. + * + * To do so, we compute sampleSize = math.ceil(size * samplingRate) for each stratum and compare + * it to the number of items that were accepted instantly and the number of items in the waitlist + * for that stratum. Most of the time, numAccepted <= sampleSize <= (numAccepted + numWaitlisted), + * which means we need to sort the elements in the waitlist by their associated values in order + * to find the value T s.t. |{elements in the stratum whose associated values <= T}| = sampleSize. + * Note that all elements in the waitlist have values >= bound for instant accept, so a T value + * in the waitlist range would allow all elements that were instantly accepted on the first pass + * to be included in the sample. + */ + def computeThresholdByKey[K](finalResult: Map[K, AcceptanceResult], + fractions: Map[K, Double]): Map[K, Double] = { + val thresholdByKey = new mutable.HashMap[K, Double]() + for ((key, acceptResult) <- finalResult) { + val sampleSize = math.ceil(acceptResult.numItems * fractions(key)).toLong + if (acceptResult.numAccepted > sampleSize) { + logWarning("Pre-accepted too many") + thresholdByKey += (key -> acceptResult.acceptBound) + } else { + val numWaitListAccepted = (sampleSize - acceptResult.numAccepted).toInt + if (numWaitListAccepted >= acceptResult.waitList.size) { + logWarning("WaitList too short") + thresholdByKey += (key -> acceptResult.waitListBound) + } else { + thresholdByKey += (key -> acceptResult.waitList.sorted.apply(numWaitListAccepted)) + } + } + } + thresholdByKey + } + + /** + * Return the per partition sampling function used for sampling without replacement. + * + * When exact sample size is required, we make an additional pass over the RDD to determine the + * exact sampling rate that guarantees sample size with high confidence. + * + * The sampling function has a unique seed per partition. + */ + def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], + fractions: Map[K, Double], + exact: Boolean, + seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + var samplingRateByKey = fractions + if (exact) { + // determine threshold for each stratum and resample + val finalResult = getAcceptanceResults(rdd, false, fractions, None, seed) + samplingRateByKey = computeThresholdByKey(finalResult, fractions) + } + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator + rng.reSeed(seed + idx) + // Must use the same invoke pattern on the rng as in getSeqOp for without replacement + // in order to generate the same sequence of random numbers when creating the sample + iter.filter(t => rng.nextUniform() < samplingRateByKey(t._1)) + } + } + + /** + * Return the per partition sampling function used for sampling with replacement. 
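
A hedged sketch of how a per-partition sampling function such as getBernoulliSamplingFunction above is assumed to be wired in. The real call site is PairRDDFunctions.sampleByKey, which is not part of this excerpt, and these utilities are private[spark], so this is illustration only: each partition gets a deterministic seed and the returned closure filters items by their stratum's rate.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.util.random.StratifiedSamplingUtils

object SampleByKeyWiringSketch {
  def stratifiedSample(sc: SparkContext): Unit = {
    val rdd = sc.parallelize(1 to 1000).map(i => (i % 2, i))     // two strata: keys 0 and 1
    val fractions = Map(0 -> 0.1, 1 -> 0.5)
    // (rdd, fractions, exact = false, seed); returns a (partitionIndex, iterator) => iterator closure.
    val sampleFunc = StratifiedSamplingUtils.getBernoulliSamplingFunction(rdd, fractions, false, 42L)
    val sampled = rdd.mapPartitionsWithIndex(sampleFunc, preservesPartitioning = true)
    println(sampled.countByKey())   // roughly 50 items for key 0 and 250 for key 1
  }
}
```
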
+ * + * When exact sample size is required, we make two additional passed over the RDD to determine + * the exact sampling rate that guarantees sample size with high confidence. The first pass + * counts the number of items in each stratum (group of items with the same key) in the RDD, and + * the second pass uses the counts to determine exact sampling rates. + * + * The sampling function has a unique seed per partition. + */ + def getPoissonSamplingFunction[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], + fractions: Map[K, Double], + exact: Boolean, + seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + // TODO implement the streaming version of sampling w/ replacement that doesn't require counts + if (exact) { + val counts = Some(rdd.countByKey()) + val finalResult = getAcceptanceResults(rdd, true, fractions, counts, seed) + val thresholdByKey = computeThresholdByKey(finalResult, fractions) + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator() + rng.reSeed(seed + idx) + iter.flatMap { item => + val key = item._1 + val acceptBound = finalResult(key).acceptBound + // Must use the same invoke pattern on the rng as in getSeqOp for with replacement + // in order to generate the same sequence of random numbers when creating the sample + val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound) + val copiesWailisted = rng.nextPoisson(finalResult(key).waitListBound) + val copiesInSample = copiesAccepted + + (0 until copiesWailisted).count(i => rng.nextUniform() < thresholdByKey(key)) + if (copiesInSample > 0) { + Iterator.fill(copiesInSample.toInt)(item) + } else { + Iterator.empty + } + } + } + } else { + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator() + rng.reSeed(seed + idx) + iter.flatMap { item => + val count = rng.nextPoisson(fractions(item._1)) + if (count > 0) { + Iterator.fill(count)(item) + } else { + Iterator.empty + } + } + } + } + } + + /** A random data generator that generates both uniform values and Poisson values. */ + private class RandomDataGenerator { + val uniform = new XORShiftRandom() + var poisson = new Poisson(1.0, new DRand) + + def reSeed(seed: Long) { + uniform.setSeed(seed) + poisson = new Poisson(1.0, new DRand(seed.toInt)) + } + + def nextPoisson(mean: Double): Int = { + poisson.nextInt(mean) + } + + def nextUniform(): Double = { + uniform.nextDouble() + } + } +} + +/** + * Object used by seqOp to keep track of the number of items accepted and items waitlisted per + * stratum, as well as the bounds for accepting and waitlisting items. 
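
The exact-size guarantee relies on the reproducibility noted above: both the threshold-computation pass and the sampling pass re-seed their generator with seed + partition index and therefore observe identical random draws. A tiny demonstration of that property, with scala.util.Random standing in for the XORShiftRandom/Poisson pair:

```scala
import scala.util.Random

object PerPartitionSeedSketch {
  // Draw the same number of uniforms a task would draw for its partition.
  def drawsFor(seed: Long, partitionIndex: Int, n: Int): Seq[Double] = {
    val rng = new Random(seed + partitionIndex)
    Seq.fill(n)(rng.nextDouble())
  }

  def main(args: Array[String]): Unit = {
    val thresholdPass = drawsFor(seed = 42L, partitionIndex = 3, n = 5)
    val samplingPass  = drawsFor(seed = 42L, partitionIndex = 3, n = 5)
    // Identical sequences: an item accepted against the computed threshold in pass one
    // is accepted again in pass two, which is what makes the exact sample size hold.
    println(thresholdPass == samplingPass)   // true
  }
}
```
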
+ * + * `[random]` here is necessary since it's in the return type signature of seqOp defined above + */ +private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L) + extends Serializable { + + val waitList = new ArrayBuffer[Double] + var acceptBound: Double = Double.NaN // upper bound for accepting item instantly + var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist + + def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN + + def merge(other: Option[AcceptanceResult]): Unit = { + if (other.isDefined) { + waitList ++= other.get.waitList + numAccepted += other.get.numAccepted + numItems += other.get.numItems + } + } +} diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index f882a8623fd84..56150caa5d6ba 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -25,18 +25,23 @@ import scala.Tuple3; import scala.Tuple4; - import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.mapred.FileSplit; +import org.apache.hadoop.mapred.InputSplit; import org.apache.hadoop.mapred.SequenceFileInputFormat; import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapred.TextInputFormat; import org.apache.hadoop.mapreduce.Job; import org.junit.After; import org.junit.Assert; @@ -44,6 +49,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaHadoopRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -1208,4 +1214,76 @@ public Tuple2 call(Integer x) { pairRDD.collect(); // Works fine pairRDD.collectAsMap(); // Used to crash with ClassCastException } + + @Test + @SuppressWarnings("unchecked") + public void sampleByKey() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i % 2, 1); + } + }); + Map fractions = Maps.newHashMap(); + fractions.put(0, 0.5); + fractions.put(1, 1.0); + JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); + Map wrCounts = (Map) (Object) wr.countByKey(); + Assert.assertTrue(wrCounts.size() == 2); + Assert.assertTrue(wrCounts.get(0) > 0); + Assert.assertTrue(wrCounts.get(1) > 0); + JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); + Map worCounts = (Map) (Object) wor.countByKey(); + Assert.assertTrue(worCounts.size() == 2); + Assert.assertTrue(worCounts.get(0) > 0); + Assert.assertTrue(worCounts.get(1) > 0); + JavaPairRDD wrExact = rdd2.sampleByKey(true, fractions, true, 1L); + Map wrExactCounts = (Map) (Object) wrExact.countByKey(); + Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertTrue(wrExactCounts.get(0) == 2); + Assert.assertTrue(wrExactCounts.get(1) == 4); + JavaPairRDD worExact = rdd2.sampleByKey(false, fractions, true, 1L); + Map worExactCounts = (Map) (Object) 
worExact.countByKey(); + Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertTrue(worExactCounts.get(0) == 2); + Assert.assertTrue(worExactCounts.get(1) == 4); + } + + private static class SomeCustomClass implements Serializable { + public SomeCustomClass() { + // Intentionally left blank + } + } + + @Test + public void collectUnderlyingScalaRDD() { + List data = new ArrayList(); + for (int i = 0; i < 100; i++) { + data.add(new SomeCustomClass()); + } + JavaRDD rdd = sc.parallelize(data); + SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); + Assert.assertEquals(data.size(), collected.length); + } + + public void getHadoopInputSplits() { + String outDir = new File(tempDir, "output").getAbsolutePath(); + sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).saveAsTextFile(outDir); + + JavaHadoopRDD hadoopRDD = (JavaHadoopRDD) + sc.hadoopFile(outDir, TextInputFormat.class, LongWritable.class, Text.class); + List inputPaths = hadoopRDD.mapPartitionsWithInputSplit( + new Function2>, Iterator>() { + @Override + public Iterator call(InputSplit split, Iterator> it) + throws Exception { + FileSplit fileSplit = (FileSplit) split; + return Lists.newArrayList(fileSplit.getPath().toUri().getPath()).iterator(); + } + }, true).collect(); + Assert.assertEquals(Sets.newHashSet(inputPaths), + Sets.newHashSet(outDir + "/part-00000", outDir + "/part-00001")); + } } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d1cb2d9d3a53b..a41914a1a9d0c 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("ShuffledRDD") { testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD - new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) + new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner) }) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 13b415cccb647..4bc4346c0a288 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -31,15 +31,28 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId} - -class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - +import org.apache.spark.storage._ +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.storage.BroadcastBlockId +import org.apache.spark.storage.RDDBlockId +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.ShuffleIndexBlockId + +/** + * An abstract base class for context cleaner tests, which sets up a context with a config + * suitable for cleaner tests and provides some utility functions. 
Subclasses can use different + * config options, in particular, a different shuffle manager class + */ +abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager]) + extends FunSuite with BeforeAndAfter with LocalSparkContext +{ implicit val defaultTimeout = timeout(10000 millis) val conf = new SparkConf() .setMaster("local[2]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.shuffle.manager", shuffleManager.getName) before { sc = new SparkContext(conf) @@ -52,9 +65,61 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } } + //------ Helper functions ------ + + protected def newRDD() = sc.makeRDD(1 to 10) + protected def newPairRDD() = newRDD().map(_ -> 1) + protected def newShuffleRDD() = newPairRDD().reduceByKey(_ + _) + protected def newBroadcast() = sc.broadcast(1 to 100) + + protected def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { + def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { + rdd.dependencies ++ rdd.dependencies.flatMap { dep => + getAllDependencies(dep.rdd) + } + } + val rdd = newShuffleRDD() + + // Get all the shuffle dependencies + val shuffleDeps = getAllDependencies(rdd) + .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) + (rdd, shuffleDeps) + } + + protected def randomRdd() = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD() + case 1 => newShuffleRDD() + case 2 => newPairRDD.join(newPairRDD()) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + /** Run GC and make sure it actually has run */ + protected def runGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. 
+ // Wait until a weak reference object has been GCed + while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + Thread.sleep(200) + } + } + + protected def cleaner = sc.cleaner.get +} + +/** + * Basic ContextCleanerSuite, which uses sort-based shuffle + */ +class ContextCleanerSuite extends ContextCleanerSuiteBase { test("cleanup RDD") { - val rdd = newRDD.persist() + val rdd = newRDD().persist() val collected = rdd.collect().toList val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) @@ -67,7 +132,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup shuffle") { - val (rdd, shuffleDeps) = newRDDWithShuffleDependencies + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies() val collected = rdd.collect().toList val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) @@ -80,7 +145,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup broadcast") { - val broadcast = newBroadcast + val broadcast = newBroadcast() val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) // Explicit cleanup @@ -89,7 +154,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup RDD") { - var rdd = newRDD.persist() + var rdd = newRDD().persist() rdd.count() // Test that GC does not cause RDD cleanup due to a strong reference @@ -107,7 +172,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup shuffle") { - var rdd = newShuffleRDD + var rdd = newShuffleRDD() rdd.count() // Test that GC does not cause shuffle cleanup due to a strong reference @@ -125,7 +190,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup broadcast") { - var broadcast = newBroadcast + var broadcast = newBroadcast() // Test that GC does not cause broadcast cleanup due to a strong reference val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) @@ -144,11 +209,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() @@ -162,6 +227,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo rddBuffer.clear() runGC() postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. 
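+    // Task closures are shipped to executors as broadcast variables, so the test assumes the
+    // closure broadcast was allocated the next id after the explicitly created broadcasts.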
+ val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { @@ -171,15 +243,16 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo .setMaster("local-cluster[2, 1, 512]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.shuffle.manager", shuffleManager.getName) sc = new SparkContext(conf2) val numRdds = 10 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() @@ -193,57 +266,90 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo rddBuffer.clear() runGC() postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. + val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } +} - //------ Helper functions ------ - def newRDD = sc.makeRDD(1 to 10) - def newPairRDD = newRDD.map(_ -> 1) - def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { - def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { - rdd.dependencies ++ rdd.dependencies.flatMap { dep => - getAllDependencies(dep.rdd) - } - } - val rdd = newShuffleRDD +/** + * A copy of the shuffle tests for sort-based shuffle + */ +class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[SortShuffleManager]) { + test("cleanup shuffle") { + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies() + val collected = rdd.collect().toList + val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) - // Get all the shuffle dependencies - val shuffleDeps = getAllDependencies(rdd) - .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) - .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) - (rdd, shuffleDeps) + // Explicit cleanup + shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true)) + tester.assertCleanup() + + // Verify that shuffles can be re-executed after cleaning up + assert(rdd.collect().toList.equals(collected)) } - def randomRdd = { - val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD - case 1 => newShuffleRDD - case 2 => newPairRDD.join(newPairRDD) - } - if (Random.nextBoolean()) rdd.persist() + test("automatically cleanup shuffle") { + var rdd = newShuffleRDD() rdd.count() - rdd - } - def randomBroadcast = { - sc.broadcast(Random.nextInt(Int.MaxValue)) + // Test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + runGC() + intercept[Exception] { + 
preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope + runGC() + postGCTester.assertCleanup() } - /** Run GC and make sure it actually has run */ - def runGC() { - val weakRef = new WeakReference(new Object()) - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - // Wait until a weak reference object has been GCed - while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - Thread.sleep(200) + test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { + sc.stop() + + val conf2 = new SparkConf() + .setMaster("local-cluster[2, 1, 512]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.shuffle.manager", shuffleManager.getName) + sc = new SparkContext(conf2) + + val numRdds = 10 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast).toBuffer + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId() + val broadcastIds = broadcastBuffer.map(_.id) + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) } - } - def cleaner = sc.cleaner.get + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() + runGC() + postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. 
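+    // askSlaves = true so that the block id query also reaches the executor block managers
+    // started by local-cluster mode, not just the driver's own block manager.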
+ val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) + } } @@ -401,6 +507,7 @@ class CleanerTester( private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { case ShuffleBlockId(`shuffleId`, _, _) => true + case ShuffleIndexBlockId(`shuffleId`, _, _) => true case _ => false }, askSlaves = true) } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index c70e22cf09433..4a53d25012ad9 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -24,12 +24,14 @@ import scala.io.Source import com.google.common.io.Files import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec -import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, TextOutputFormat} -import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.FunSuite import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils class FileSuite extends FunSuite with LocalSparkContext { @@ -318,4 +320,32 @@ class FileSuite extends FunSuite with LocalSparkContext { randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } + + test("Get input files via old Hadoop API") { + sc = new SparkContext("local", "test") + val outDir = new File(tempDir, "output").getAbsolutePath + sc.makeRDD(1 to 4, 2).saveAsTextFile(outDir) + + val inputPaths = + sc.hadoopFile(outDir, classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[HadoopRDD[_, _]] + .mapPartitionsWithInputSplit { (split, part) => + Iterator(split.asInstanceOf[FileSplit].getPath.toUri.getPath) + }.collect() + assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + } + + test("Get input files via new Hadoop API") { + sc = new SparkContext("local", "test") + val outDir = new File(tempDir, "output").getAbsolutePath + sc.makeRDD(1 to 4, 2).saveAsTextFile(outDir) + + val inputPaths = + sc.newAPIHadoopFile(outDir, classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[NewHadoopRDD[_, _]] + .mapPartitionsWithInputSplit { (split, part) => + Iterator(split.asInstanceOf[NewFileSplit].getPath.toUri.getPath) + }.collect() + assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 4658a08064280..fc0cee3e8749d 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import scala.collection.mutable.ArrayBuffer import scala.math.abs import org.scalatest.{FunSuite, PrivateMethodTester} @@ -52,14 +53,12 @@ 
class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(p2 === p2) assert(p4 === p4) - assert(p2 != p4) - assert(p4 != p2) + assert(p2 === p4) assert(p4 === anotherP4) assert(anotherP4 === p4) assert(descendingP2 === descendingP2) assert(descendingP4 === descendingP4) - assert(descendingP2 != descendingP4) - assert(descendingP4 != descendingP2) + assert(descendingP2 === descendingP4) assert(p2 != descendingP2) assert(p4 != descendingP4) assert(descendingP2 != p2) @@ -102,6 +101,63 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet partitioner.getPartition(Row(100)) } + test("RangPartitioner.sketch") { + val rdd = sc.makeRDD(0 until 20, 20).flatMap { i => + val random = new java.util.Random(i) + Iterator.fill(i)(random.nextDouble()) + }.cache() + val sampleSizePerPartition = 10 + val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition) + assert(count === rdd.count()) + sketched.foreach { case (idx, n, sample) => + assert(n === idx) + assert(sample.size === math.min(n, sampleSizePerPartition)) + } + } + + test("RangePartitioner.determineBounds") { + assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty, + "Bounds on an empty candidates set should be empty.") + val candidates = ArrayBuffer( + (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f)) + assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7)) + } + + test("RangePartitioner should run only one job if data is roughly balanced") { + val rdd = sc.makeRDD(0 until 20, 20).flatMap { i => + val random = new java.util.Random(i) + Iterator.fill(5000 * i)((random.nextDouble() + i, i)) + }.cache() + for (numPartitions <- Seq(10, 20, 40)) { + val partitioner = new RangePartitioner(numPartitions, rdd) + assert(partitioner.numPartitions === numPartitions) + val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values + assert(counts.max < 3.0 * counts.min) + } + } + + test("RangePartitioner should work well on unbalanced data") { + val rdd = sc.makeRDD(0 until 20, 20).flatMap { i => + val random = new java.util.Random(i) + Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i)) + }.cache() + for (numPartitions <- Seq(2, 4, 8)) { + val partitioner = new RangePartitioner(numPartitions, rdd) + assert(partitioner.numPartitions === numPartitions) + val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values + assert(counts.max < 3.0 * counts.min) + } + } + + test("RangePartitioner should return a single partition for empty RDDs") { + val empty1 = sc.emptyRDD[(Int, Double)] + val partitioner1 = new RangePartitioner(0, empty1) + assert(partitioner1.numPartitions === 1) + val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)]) + val partitioner2 = new RangePartitioner(2, empty2) + assert(partitioner2.numPartitions === 1) + } + test("HashPartitioner not equal to RangePartitioner") { val rdd = sc.parallelize(1 to 10).map(x => (x, x)) val rangeP2 = new RangePartitioner(2, rdd) diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index 47df00050c1e2..d7b2d2e1e330f 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -28,6 +28,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { } override def afterAll() { - 
System.setProperty("spark.shuffle.use.netty", "false") + System.clearProperty("spark.shuffle.use.netty") } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index eae67c7747e82..b13ddf96bc77c 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -58,8 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, - NonJavaSerializableClass, - (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS)) + NonJavaSerializableClass](b, new HashPartitioner(NUM_BLOCKS)) c.setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -83,8 +82,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, - NonJavaSerializableClass, - (Int, NonJavaSerializableClass)](b, new HashPartitioner(3)) + NonJavaSerializableClass](b, new HashPartitioner(3)) c.setSerializer(new KryoSerializer(conf)) assert(c.count === 10) } @@ -100,7 +98,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // NOTE: The default Java serializer doesn't create zero-sized blocks. // So, use Kryo - val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -126,7 +124,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) @@ -141,19 +139,19 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } - test("shuffle using mutable pairs") { + test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test") def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs, + val results = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(2)).collect() - data.foreach { pair => results should contain (pair) } + data.foreach { pair => results should contain ((pair._1, pair._2)) } } - test("sorting using mutable pairs") { + test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. 
// Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test") @@ -162,10 +160,10 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) .sortByKey().collect() - results(0) should be (p(1, 11)) - results(1) should be (p(2, 22)) - results(2) should be (p(3, 33)) - results(3) should be (p(100, 100)) + results(0) should be ((1, 11)) + results(1) should be ((2, 22)) + results(2) should be ((3, 33)) + results(3) should be ((100, 100)) } test("cogroup using mutable pairs") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala similarity index 60% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala rename to core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 4589129cd1c90..5c02c00586ef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -17,20 +17,18 @@ package org.apache.spark -/** - * Allows the execution of relational queries, including those expressed in SQL using Spark. - * - * Note that this package is located in catalyst instead of in core so that all subprojects can - * inherit the settings from this package object. - */ -package object sql { +import org.scalatest.BeforeAndAfterAll - protected[sql] def Logger(name: String) = - com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name)) +class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { - protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging + // This test suite should run all tests in ShuffleSuite with sort-based shuffle. - type Row = catalyst.expressions.Row + override def beforeAll() { + System.setProperty("spark.shuffle.manager", + "org.apache.spark.shuffle.sort.SortShuffleManager") + } - val Row = catalyst.expressions.Row + override def afterAll() { + System.clearProperty("spark.shuffle.manager") + } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 67e3be21c3c93..495a0d48633a4 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, FunSuite, PrivateMethodTester} import org.apache.spark.scheduler.{TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} @@ -25,12 +25,12 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite - extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { + extends FunSuite with PrivateMethodTester with Logging with BeforeAndAfterEach { def createTaskScheduler(master: String): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. 
- sc = new SparkContext("local", "test") + val sc = new SparkContext("local", "test") val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) sched.asInstanceOf[TaskSchedulerImpl] @@ -68,6 +68,15 @@ class SparkContextSchedulerCreationSuite } } + test("local-*-n-failures") { + val sched = createTaskScheduler("local[* ,2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case _ => fail() + } + } + test("local-n-failures") { val sched = createTaskScheduler("local[4, 2]") assert(sched.maxTaskFailures === 2) @@ -77,6 +86,20 @@ class SparkContextSchedulerCreationSuite } } + test("bad-local-n") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("bad-local-n-failures") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*,4]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + test("local-default-parallelism") { val defaultParallelism = System.getProperty("spark.default.parallelism") System.setProperty("spark.default.parallelism", "16") diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 01ab2d549325c..093394ad6d142 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -88,7 +88,7 @@ class JsonProtocolSuite extends FunSuite { } def createAppDesc(): ApplicationDescription = { - val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq()) + val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) new ApplicationDescription("name", Some(4), 1234, cmd, Some("sparkHome"), "appUiUrl") } @@ -101,7 +101,7 @@ class JsonProtocolSuite extends FunSuite { def createDriverCommand() = new Command( "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), - Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Some("-Dfoo") + Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, @@ -170,7 +170,7 @@ object JsonConstants { """ |{"name":"name","cores":4,"memoryperslave":1234, |"user":"%s","sparkhome":"sparkHome", - |"command":"Command(mainClass,List(arg1, arg2),Map(),List(),List(),None)"} + |"command":"Command(mainClass,List(arg1, arg2),Map(),List(),List(),List())"} """.format(System.getProperty("user.name", "")).stripMargin val executorRunnerJsonStr = diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index f497a5e0a14f0..9190b05e2dba2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -200,9 +200,12 @@ class SparkSubmitSuite extends FunSuite with Matchers { childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") mainClass should be ("org.apache.spark.deploy.Client") classpath should have size (0) - sysProps should have size (3) - sysProps.keys should contain ("spark.jars") + sysProps should have size (5) sysProps.keys should contain 
("SPARK_SUBMIT") + sysProps.keys should contain ("spark.master") + sysProps.keys should contain ("spark.app.name") + sysProps.keys should contain ("spark.jars") + sysProps.keys should contain ("spark.shuffle.spill") sysProps("spark.shuffle.spill") should be ("false") } @@ -250,6 +253,22 @@ class SparkSubmitSuite extends FunSuite with Matchers { sysProps("spark.shuffle.spill") should be ("false") } + test("handles confs with flag equivalents") { + val clArgs = Seq( + "--deploy-mode", "cluster", + "--executor-memory", "5g", + "--class", "org.SomeClass", + "--conf", "spark.executor.memory=4g", + "--conf", "spark.master=yarn", + "thejar.jar", + "arg1", "arg2") + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, sysProps, mainClass) = createLaunchEnv(appArgs) + sysProps("spark.executor.memory") should be ("5g") + sysProps("spark.master") should be ("yarn-cluster") + mainClass should be ("org.apache.spark.deploy.yarn.Client") + } + test("launch simple application with spark-submit") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 4633bc3f7f25e..c930839b47f11 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -29,7 +29,7 @@ import org.apache.spark.deploy.{Command, DriverDescription} class DriverRunnerTest extends FunSuite { private def createDriverRunner() = { - val command = new Command("mainClass", Seq(), Map(), Seq(), Seq()) + val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription, null, "akka://1.2.3.4/worker/") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index e5f748d55500d..ca4d987619c91 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -29,7 +29,7 @@ class ExecutorRunnerTest extends FunSuite { def f(s:String) = new File(s) val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")) val appDesc = new ApplicationDescription("app name", Some(8), 500, - Command("foo", Seq(), Map(), Seq(), Seq()), + Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), sparkHome, "appUiUrl") val appId = "12345-worker321-9876" val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome.getOrElse(".")), diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 447e38ec9dbd0..4f49d4a1d4d34 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -83,6 +83,122 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { assert(valuesFor2.toList.sorted === List(1)) } + test("sampleByKey") { + def stratifier (fractionPositive: Double) = { + (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" + } + + def checkSize(exact: Boolean, + withReplacement: Boolean, + expected: Long, + actual: Long, + 
p: Double): Boolean = { + if (exact) { + return expected == actual + } + val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + + // Without replacement validation + def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey() + .mapValues(count => math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = stratifiedData.sampleByKey(false, fractions, exact, seed) + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + assert(takeSample.size === takeSample.toSet.size) + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + // With replacement validation + def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey().mapValues(count => + math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = stratifiedData.sampleByKey(true, fractions, exact, seed) + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) } + val groupedByKey = takeSample.groupBy(_._1) + for ((key, v) <- groupedByKey) { + if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { + // sample large enough for there to be repeats with high likelihood + assert(v.toSet.size < expectedSampleSize(key)) + } else { + if (exact) { + assert(v.toSet.size <= expectedSampleSize(key)) + } else { + assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + } + } + } + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + def checkAllCombos(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n) + takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n) + takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n) + takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n) + } + + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 1000000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(stratifier(fractionPositive)) + + val samplingRate = 0.1 + checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(stratifier(fractionPositive)) + + val samplingRate = 0.1 + checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + } + + // Use the same data for the rest of the tests + val fractionPositive = 0.3 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(stratifier(fractionPositive)) + + // vary seed + for (seed <- defaultSeed to defaultSeed + 5L) { + val 
samplingRate = 0.1 + checkAllCombos(stratifiedData, samplingRate, seed, n) + } + + // vary sampling rate + for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { + checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + } + } + test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_+_).collect() diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6654ec2d7c656..b31e3a09e5b9c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.scalatest.FunSuite @@ -26,6 +27,7 @@ import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.util.Utils +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ class RDDSuite extends FunSuite with SharedSparkContext { @@ -121,6 +123,18 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(union.partitioner === nums1.partitioner) } + test("UnionRDD partition serialized size should be small") { + val largeVariable = new Array[Byte](1000 * 1000) + val rdd1 = sc.parallelize(1 to 10, 2).map(i => largeVariable.length) + val rdd2 = sc.parallelize(1 to 10, 3) + + val ser = SparkEnv.get.closureSerializer.newInstance() + val union = rdd1.union(rdd2) + // The UnionRDD itself should be large, but each individual partition should be small. + assert(ser.serialize(union).limit() > 2000) + assert(ser.serialize(union.partitions.head).limit() < 2000) + } + test("aggregate") { val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] @@ -155,19 +169,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { override def getPartitions: Array[Partition] = Array(onlySplit) override val getDependencies = List[Dependency[_]]() override def compute(split: Partition, context: TaskContext): Iterator[Int] = { - if (shouldFail) { - throw new Exception("injected failure") - } else { - Array(1, 2, 3, 4).iterator - } + throw new Exception("injected failure") } }.cache() val thrown = intercept[Exception]{ rdd.collect() } assert(thrown.getMessage.contains("injected failure")) - shouldFail = false - assert(rdd.collect().toList === List(1, 2, 3, 4)) } test("empty RDD") { @@ -276,7 +284,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { // we can optionally shuffle to keep the upstream parallel val coalesced5 = data.coalesce(1, shuffle = true) val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd. 
- asInstanceOf[ShuffledRDD[_, _, _, _]] != null + asInstanceOf[ShuffledRDD[_, _, _]] != null assert(isEquals) // when shuffling, we can increase the number of partitions @@ -613,6 +621,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("sort an empty RDD") { + val data = sc.emptyRDD[Int] + assert(data.sortBy(x => x).collect() === Array.empty) + } + test("sortByKey") { val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) @@ -707,6 +720,12 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(ids.length === n) } + test("retag with implicit ClassTag") { + val jsc: JavaSparkContext = new JavaSparkContext(sc) + val jrdd: JavaRDD[String] = jsc.parallelize(Seq("A", "B", "C").asJava) + jrdd.rdd.retag.collect() + } + test("getNarrowAncestors") { val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1) @@ -731,9 +750,9 @@ class RDDSuite extends FunSuite with SharedSparkContext { // Any ancestors before the shuffle are not considered assert(ancestors4.size === 0) - assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0) + assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 0) assert(ancestors5.size === 3) - assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1) + assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1) assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0) assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9021662bcf712..36e238b4c9434 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -23,12 +23,15 @@ import scala.language.reflectiveCalls import akka.actor._ import akka.testkit.{ImplicitSender, TestKit, TestActorRef} import org.scalatest.{BeforeAndAfter, FunSuiteLike} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite +import org.apache.spark.executor.TaskMetrics class BuggyDAGEventProcessActor extends Actor { val state = 0 @@ -63,7 +66,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike - with ImplicitSender with BeforeAndAfter with LocalSparkContext { + with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. 
*/ @@ -77,6 +80,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def schedulingMode: SchedulingMode = SchedulingMode.NONE override def start() = {} override def stop() = {} + override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -291,6 +296,18 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } + test("avoid exponential blowup when getting preferred locs list") { + // Build up a complex dependency graph with repeated zip operations, without preferred locations. + var rdd: RDD[_] = new MyRDD(sc, 1, Nil) + (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) + // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. + failAfter(10 seconds) { + val preferredLocs = scheduler.getPreferredLocs(rdd,0) + // No preferred locations are returned. + assert(preferredLocs.length === 0) + } + } + test("unserializable task") { val unserializableRdd = new MyRDD(sc, 1, Nil) { class UnserializableClass @@ -342,6 +359,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 + override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true } val noKillScheduler = new DAGScheduler( sc, diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 21e3db34b8b7a..10d8b299317ea 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -259,7 +259,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val expectedLogDir = logDirPath.toString - assert(eventLogger.logDir.startsWith(expectedLogDir)) + assert(eventLogger.logDir.contains(expectedLogDir)) // Begin listening for events that trigger asserts val eventExistenceListener = new EventExistenceListener(eventLogger) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8bb5317cd2875..270f7e661045a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,31 +20,35 @@ package org.apache.spark.scheduler import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter -import org.apache.spark.LocalSparkContext -import org.apache.spark.Partition -import org.apache.spark.SparkContext -import org.apache.spark.TaskContext +import org.apache.spark._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { test("Calls executeOnCompleteCallbacks after failure") { - var completed = false + TaskContextSuite.completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def 
compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => completed = true) + context.addOnCompleteCallback(() => TaskContextSuite.completed = true) sys.error("failed") } } - val func = (c: TaskContext, i: Iterator[String]) => i.next - val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val task = new ResultTask[String, String]( + 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0) } - assert(completed === true) + assert(TaskContextSuite.completed === true) } +} - case class StubPartition(val index: Int) extends Partition +private object TaskContextSuite { + @volatile var completed = false } + +private case class StubPartition(index: Int) extends Partition diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 79280d1a06653..789b773bae316 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -209,6 +209,36 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } +class KryoSerializerResizableOutputSuite extends FunSuite { + import org.apache.spark.SparkConf + import org.apache.spark.SparkContext + import org.apache.spark.LocalSparkContext + import org.apache.spark.SparkException + + // trial and error showed this will not serialize with 1mb buffer + val x = (1 to 400000).toArray + + test("kryo without resizable output buffer should fail on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "1") + val sc = new SparkContext("local", "test", conf) + intercept[SparkException](sc.parallelize(x).collect) + LocalSparkContext.stop(sc) + } + + test("kryo with resizable output buffer should succeed on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "2") + val sc = new SparkContext("local", "test", conf) + assert(sc.parallelize(x).collect === x) + LocalSparkContext.stop(sc) + } +} + object KryoTest { case class CaseClass(i: Int, s: String) {} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala new file mode 100644 index 0000000000000..8dca2ebb312f5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.{FunSuite, Matchers} +import org.scalatest.PrivateMethodTester._ + +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.stubbing.Answer +import org.mockito.invocation.InvocationOnMock + +import org.apache.spark._ +import org.apache.spark.storage.BlockFetcherIterator._ +import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, + Message} + +class BlockFetcherIteratorSuite extends FunSuite with Matchers { + + test("block fetch from local fails using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + doReturn(connManager).when(blockManager).connectionManager + doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId + + doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + val answer = new Answer[Option[Iterator[Any]]] { + override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { + throw new Exception + } + } + + // 3rd block is going to fail + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) + doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + + // 3rd getLocalFromDisk invocation should be failed + verify(blockManager, times(3)).getLocalFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully + assert(iterator.next._2.isDefined, "1st element should be defined but is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next._2.isDefined, "2nd element should be defined but is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + // 3rd fetch should be failed + assert(!iterator.next._2.isDefined, "3rd element should not be defined but is actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + // Don't call next() after fetching non-defined element even if thare are rest of elements in the iterator. 
+ // Otherwise, BasicBlockFetcherIterator hangs up. + } + + + test("block fetch from local succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + doReturn(connManager).when(blockManager).connectionManager + doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId + + doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + + // All blocks should be fetched successfully + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + + // getLocalFromDis should be invoked for all of 5 blocks + verify(blockManager, times(5)).getLocalFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index dd4fd535d3577..0ac0269d7cfc1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,25 +19,28 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays +import java.util.concurrent.TimeUnit import akka.actor._ -import org.apache.spark.SparkConf -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} +import akka.pattern.ask +import akka.util.Timeout + import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ 
import org.scalatest.Matchers -import org.scalatest.time.SpanSugar._ import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Await +import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps @@ -76,7 +79,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter oldArch = System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.storage.disableBlockManagerHeartBeat", "true") conf.set("spark.driver.port", boundPort.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -344,7 +346,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("reregistration on heart beat") { - val heartBeat = PrivateMethod[Unit]('heartBeat) store = makeBlockManager(2000) val a1 = new Array[Byte](400) @@ -356,13 +357,15 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - store invokePrivate heartBeat() - assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + implicit val timeout = Timeout(30, TimeUnit.SECONDS) + val reregister = !Await.result( + master.driverActor ? 
BlockManagerHeartbeat(store.blockManagerId), + timeout.duration).asInstanceOf[Boolean] + assert(reregister == true) } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -380,7 +383,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("reregistration doesn't dead lock") { - val heartBeat = PrivateMethod[Unit]('heartBeat) store = makeBlockManager(2000) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -400,7 +402,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } val t3 = new Thread { override def run() { - store invokePrivate heartBeat() + store.reregister() } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index aaa7714049732..985ac9394738c 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -22,11 +22,14 @@ import java.io.{File, FileWriter} import scala.collection.mutable import scala.language.reflectiveCalls +import akka.actor.Props import com.google.common.io.Files import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.util.Utils +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.{AkkaUtils, Utils} class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -121,6 +124,88 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before newFile.delete() } + private def checkSegments(segment1: FileSegment, segment2: FileSegment) { + assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath) + assert (segment1.offset === segment2.offset) + assert (segment1.length === segment2.length) + } + + test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { + + val serializer = new JavaSerializer(testConf) + val confCopy = testConf.clone + // reset after EACH object write. This is to ensure that there are bytes appended after + // an object is written. So if the codepaths assume writeObject is end of data, this should + // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. + confCopy.set("spark.serializer.objectStreamReset", "1") + + val securityManager = new org.apache.spark.SecurityManager(confCopy) + // Do not use the shuffleBlockManager above ! 
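+    // A dedicated actor system, BlockManagerMaster and BlockManager are created for this test
+    // so its shuffle files go through their own shuffle block manager rather than the one used
+    // by the rest of the suite; everything is torn down again in the finally block below.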
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy, + securityManager) + val master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), + confCopy) + val store = new BlockManager("", actorSystem, master , serializer, confCopy, + securityManager, null) + + try { + + val shuffleManager = store.shuffleBlockManager + + val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer) + for (writer <- shuffle1.writers) { + writer.write("test1") + writer.write("test2") + } + for (writer <- shuffle1.writers) { + writer.commitAndClose() + } + + val shuffle1Segment = shuffle1.writers(0).fileSegment() + shuffle1.releaseWriters(success = true) + + val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf)) + + for (writer <- shuffle2.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle2.writers) { + writer.commitAndClose() + } + val shuffle2Segment = shuffle2.writers(0).fileSegment() + shuffle2.releaseWriters(success = true) + + // Now comes the test : + // Write to shuffle 3; and close it, but before registering it, check if the file lengths for + // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length + // of block based on remaining data in file : which could mess things up when there is concurrent read + // and writes happening to the same shuffle group. + + val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf)) + for (writer <- shuffle3.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle3.writers) { + writer.commitAndClose() + } + // check before we register. + checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0))) + shuffle3.releaseWriters(success = true) + checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0))) + shuffleManager.removeShuffle(1) + } finally { + + if (store != null) { + store.stop() + } + actorSystem.shutdown() + actorSystem.awaitTermination() + } + } + def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) { val segment = diskBlockManager.getBlockLocation(blockId) assert(segment.file.getName === filename) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index b52f81877d557..cb8252515238e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -21,11 +21,13 @@ import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark._ -import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.{LocalSparkContext, SparkConf, Success} +import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { + test("test LRU eviction of stages") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) @@ -66,7 +68,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) 
taskInfo.finishTime = 1 - var task = new ShuffleMapTask(0, null, null, 0, null) + var task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) @@ -76,14 +78,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo = new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.size === 1) // finish this task, should get updated duration taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) .shuffleRead === 2000) @@ -91,7 +93,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated duration taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) .shuffleRead === 1000) @@ -103,7 +105,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val metrics = new TaskMetrics() val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - val task = new ShuffleMapTask(0, null, null, 0, null) + val task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) // Go through all the failure cases to make sure we are counting them as failures. 
@@ -128,4 +130,87 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1) assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) } + + test("test update metrics") { + val conf = new SparkConf() + val listener = new JobProgressListener(conf) + + val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) + val execId = "exe-1" + + def makeTaskMetrics(base: Int) = { + val taskMetrics = new TaskMetrics() + val shuffleReadMetrics = new ShuffleReadMetrics() + val shuffleWriteMetrics = new ShuffleWriteMetrics() + taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + shuffleReadMetrics.remoteBytesRead = base + 1 + shuffleReadMetrics.remoteBlocksFetched = base + 2 + shuffleWriteMetrics.shuffleBytesWritten = base + 3 + taskMetrics.executorRunTime = base + 4 + taskMetrics.diskBytesSpilled = base + 5 + taskMetrics.memoryBytesSpilled = base + 6 + taskMetrics + } + + def makeTaskInfo(taskId: Long, finishTime: Int = 0) = { + val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, + false) + taskInfo.finishTime = finishTime + taskInfo + } + + listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1234L))) + listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1235L))) + listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1236L))) + listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1237L))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( + (1234L, 0, makeTaskMetrics(0)), + (1235L, 0, makeTaskMetrics(100)), + (1236L, 1, makeTaskMetrics(200))))) + + var stage0Data = listener.stageIdToData.get(0).get + var stage1Data = listener.stageIdToData.get(1).get + assert(stage0Data.shuffleReadBytes == 102) + assert(stage1Data.shuffleReadBytes == 201) + assert(stage0Data.shuffleWriteBytes == 106) + assert(stage1Data.shuffleWriteBytes == 203) + assert(stage0Data.executorRunTime == 108) + assert(stage1Data.executorRunTime == 204) + assert(stage0Data.diskBytesSpilled == 110) + assert(stage1Data.diskBytesSpilled == 205) + assert(stage0Data.memoryBytesSpilled == 112) + assert(stage1Data.memoryBytesSpilled == 206) + assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 2) + assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 102) + assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 202) + + // task that was included in a heartbeat + listener.onTaskEnd(SparkListenerTaskEnd(0, taskType, Success, makeTaskInfo(1234L, 1), + makeTaskMetrics(300))) + // task that wasn't included in a heartbeat + listener.onTaskEnd(SparkListenerTaskEnd(1, taskType, Success, makeTaskInfo(1237L, 1), + makeTaskMetrics(400))) + + stage0Data = listener.stageIdToData.get(0).get + stage1Data = listener.stageIdToData.get(1).get + assert(stage0Data.shuffleReadBytes == 402) + assert(stage1Data.shuffleReadBytes == 602) + assert(stage0Data.shuffleWriteBytes == 406) + assert(stage1Data.shuffleWriteBytes == 606) + assert(stage0Data.executorRunTime == 408) + assert(stage1Data.executorRunTime == 608) + assert(stage0Data.diskBytesSpilled == 410) + assert(stage1Data.diskBytesSpilled == 610) + assert(stage0Data.memoryBytesSpilled == 412) + assert(stage1Data.memoryBytesSpilled == 612) + 
assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 302) + assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 402) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 0b7ad184a46d2..7de5df6e1c8bd 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -208,11 +208,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val resultA = rddA.reduceByKey(math.max).collect() assert(resultA.length == 50000) resultA.foreach { case(k, v) => - k match { - case 0 => assert(v == 1) - case 25000 => assert(v == 50001) - case 49999 => assert(v == 99999) - case _ => + if (v != k * 2 + 1) { + fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") } } @@ -221,11 +218,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val resultB = rddB.groupByKey().collect() assert(resultB.length == 25000) resultB.foreach { case(i, seq) => - i match { - case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3)) - case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003)) - case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999)) - case _ => + val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) + if (seq.toSet != expected) { + fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") } } @@ -239,6 +234,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { case 0 => assert(seq1.toSet == Set[Int](0)) assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + case 1 => + assert(seq1.toSet == Set[Int](1)) + assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) case 5000 => assert(seq1.toSet == Set[Int](5000)) assert(seq2.toSet == Set[Int]()) @@ -369,10 +367,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } } - -/** - * A dummy class that always returns the same hash code, to easily test hash collisions - */ -case class FixedHashObject(v: Int, h: Int) extends Serializable { - override def hashCode(): Int = h -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala new file mode 100644 index 0000000000000..65a71e5a83698 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -0,0 +1,576 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +class ExternalSorterSuite extends FunSuite with LocalSparkContext { + test("empty data stream") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter.iterator.toSeq === Seq()) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(3)), None, None) + assert(sorter2.iterator.toSeq === Seq()) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter3.iterator.toSeq === Seq()) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), None, None) + assert(sorter4.iterator.toSeq === Seq()) + sorter4.stop() + } + + test("few elements per partition") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val elements = Set((1, 1), (2, 2), (5, 5)) + val expected = Set( + (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()), + (5, Set((5, 5))), (6, Set())) + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(7)), Some(ord), None) + sorter.write(elements.iterator) + assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(7)), None, None) + sorter2.write(elements.iterator) + assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), Some(ord), None) + sorter3.write(elements.iterator) + assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), None, None) + sorter4.write(elements.iterator) + assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter4.stop() + } + + test("empty partitions with spilling") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = 
implicitly[Ordering[Int]] + val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), None, None) + sorter.write(elements) + assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === (0, Nil)) + assert(iter.next() === (1, List((1, 1)))) + assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) + assert(iter.next() === (3, Nil)) + assert(iter.next() === (4, Nil)) + assert(iter.next() === (5, List((5, 5)))) + assert(iter.next() === (6, Nil)) + sorter.stop() + } + + test("spilling in local cluster") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + // reduceByKey - should spill ~8 times + val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max).collect() + assert(resultA.length == 50000) + resultA.foreach { case(k, v) => + if (v != k * 2 + 1) { + fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") + } + } + + // groupByKey - should spill ~17 times + val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultB = rddB.groupByKey().collect() + assert(resultB.length == 25000) + resultB.foreach { case(i, seq) => + val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) + if (seq.toSet != expected) { + fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") + } + } + + // cogroup - should spill ~7 times + val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) + val resultC = rddC1.cogroup(rddC2).collect() + assert(resultC.length == 10000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + case 1 => + assert(seq1.toSet == Set[Int](1)) + assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) + case 5000 => + assert(seq1.toSet == Set[Int](5000)) + assert(seq2.toSet == Set[Int]()) + case 9999 => + assert(seq1.toSet == Set[Int](9999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + + // larger cogroup - should spill ~7 times + val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val resultD = rddD1.cogroup(rddD2).collect() + assert(resultD.length == 5000) + resultD.foreach { case(i, (seq1, seq2)) => + val expected = Set(i * 2, i * 2 + 1) + if (seq1.toSet != expected) { + fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") + } + if (seq2.toSet != expected) { + fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") + } + } + + // sortByKey - should spill ~17 times + val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultE = rddE.sortByKey().collect().toSeq + assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) + } + + test("spilling in local cluster with many reduce tasks") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + 
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + + // reduceByKey - should spill ~4 times per executor + val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max _, 100).collect() + assert(resultA.length == 50000) + resultA.foreach { case(k, v) => + if (v != k * 2 + 1) { + fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") + } + } + + // groupByKey - should spill ~8 times per executor + val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultB = rddB.groupByKey(100).collect() + assert(resultB.length == 25000) + resultB.foreach { case(i, seq) => + val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) + if (seq.toSet != expected) { + fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") + } + } + + // cogroup - should spill ~4 times per executor + val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) + val resultC = rddC1.cogroup(rddC2, 100).collect() + assert(resultC.length == 10000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + case 1 => + assert(seq1.toSet == Set[Int](1)) + assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) + case 5000 => + assert(seq1.toSet == Set[Int](5000)) + assert(seq2.toSet == Set[Int]()) + case 9999 => + assert(seq1.toSet == Set[Int](9999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + + // larger cogroup - should spill ~4 times per executor + val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val resultD = rddD1.cogroup(rddD2).collect() + assert(resultD.length == 5000) + resultD.foreach { case(i, (seq1, seq2)) => + val expected = Set(i * 2, i * 2 + 1) + if (seq1.toSet != expected) { + fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") + } + if (seq2.toSet != expected) { + fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") + } + } + + // sortByKey - should spill ~8 times per executor + val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultE = rddE.sortByKey().collect().toSeq + assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) + } + + test("cleanup of intermediate files in sorter") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + + val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + sorter2.write((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) + sorter2.stop() + assert(diskBlockManager.getAllBlocks().length === 
0) + } + + test("cleanup of intermediate files in sorter if there are errors") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + intercept[SparkException] { + sorter.write((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in shuffle") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val data = sc.parallelize(0 until 100000, 2).map(i => (i, i)) + assert(data.reduceByKey(_ + _).count() === 100000) + + // After the shuffle, there should be only 4 files on disk: our two map output files and + // their index files. All other intermediate files should've been deleted. + assert(diskBlockManager.getAllFiles().length === 4) + } + + test("cleanup of intermediate files in shuffle with errors") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val data = sc.parallelize(0 until 100000, 2).map(i => { + if (i == 99990) { + throw new Exception("Intentional failure") + } + (i, i) + }) + intercept[SparkException] { + data.reduceByKey(_ + _).count() + } + + // After the shuffle, there should be only 2 files on disk: the output of task 1 and its index. + // All other files (map 2's output and intermediate merge files) should've been deleted. 
+ assert(diskBlockManager.getAllFiles().length === 2) + } + + test("no partial aggregation or sorting") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100000).iterator.map(i => (i / 4, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("partial aggregation without spill") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100).iterator.map(i => (i / 2, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("partial aggregation with spill, no ordering") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("partial aggregation with spill, with ordering") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("sorting without aggregation, no spill") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val ord = implicitly[Ordering[Int]] + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + sorter.write((0 until 100).iterator.map(i => (i, i))) + val results = 
sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq + val expected = (0 until 3).map(p => { + (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) + }).toSeq + assert(results === expected) + } + + test("sorting without aggregation, with spill") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val ord = implicitly[Ordering[Int]] + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + sorter.write((0 until 100000).iterator.map(i => (i, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq + val expected = (0 until 3).map(p => { + (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) + }).toSeq + assert(results === expected) + } + + test("spilling with hash collisions") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + def createCombiner(i: String) = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i + def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) = + buffer1 ++= buffer2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner _, mergeValue _, mergeCombiners _) + + val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + + val collisionPairs = Seq( + ("Aa", "BB"), // 2112 + ("to", "v1"), // 3707 + ("variants", "gelato"), // -1249574770 + ("Teheran", "Siblings"), // 231609873 + ("misused", "horsemints"), // 1069518484 + ("isohel", "epistolaries"), // -1179291542 + ("righto", "buzzards"), // -931102253 + ("hierarch", "crinolines"), // -1732884796 + ("inwork", "hypercatalexes"), // -1183663690 + ("wainages", "presentencing"), // 240183619 + ("trichothecenes", "locular"), // 339006536 + ("pomatoes", "eructation") // 568647356 + ) + + collisionPairs.foreach { case (w1, w2) => + // String.hashCode is documented to use a specific algorithm, but check just in case + assert(w1.hashCode === w2.hashCode) + } + + val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++ + collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) + + sorter.write(toInsert) + + // A map of collision pairs in both directions + val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap + + // Avoid map.size or map.iterator.length because this destructively sorts the underlying map + var count = 0 + + val it = sorter.iterator + while (it.hasNext) { + val kv = it.next() + val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1)) + assert(kv._2.equals(expectedValue)) + count += 1 + } + assert(count === 100000 + collisionPairs.size * 2) + } + + test("spilling with many hash collisions") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.0001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) + val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) + + // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes + // problems if the map fails to group together the objects with the same code (SPARK-2043). 
+ val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1) + sorter.write(toInsert.iterator) + + val it = sorter.iterator + var count = 0 + while (it.hasNext) { + val kv = it.next() + assert(kv._2 === 10) + count += 1 + } + assert(count === 10000) + } + + test("spilling with hash collisions using the Int.MaxValue key") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + def createCombiner(i: Int) = ArrayBuffer[Int](i) + def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i + def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + + val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) + val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) + + sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + + val it = sorter.iterator + while (it.hasNext) { + // Should not throw NoSuchElementException + it.next() + } + } + + test("spilling with null keys and values") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + def createCombiner(i: String) = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i + def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + + sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + (null.asInstanceOf[String], "1"), + ("1", null.asInstanceOf[String]), + (null.asInstanceOf[String], null.asInstanceOf[String]) + )) + + val it = sorter.iterator + while (it.hasNext) { + // Should not throw NullPointerException + it.next() + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala similarity index 77% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala rename to core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala index 582334aa42590..c787b5f066e00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala @@ -15,10 +15,11 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.plans.logical +package org.apache.spark.util.collection -abstract class BaseRelation extends LeafNode { - self: Product => - - def tableName: String +/** + * A dummy class that always returns the same hash code, to easily test hash collisions + */ +case class FixedHashObject(v: Int, h: Int) extends Serializable { + override def hashCode(): Int = h } diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index 7257d17d10116..a21410f3b9813 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{ListBuffer, Queue} import org.apache.spark.SparkConf import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.hive.LocalHiveContext +import org.apache.spark.sql.hive.HiveContext case class Person(name: String, age: Int) @@ -34,7 +34,7 @@ object SparkSqlExample { case None => new SparkConf().setAppName("Simple Sql App") } val sc = new SparkContext(conf) - val hiveContext = new LocalHiveContext(sc) + val hiveContext = new HiveContext(sc) import hiveContext._ hql("DROP TABLE IF EXISTS src") diff --git a/dev/check-license b/dev/check-license index fbd2dd465bb18..625ec161bc571 100755 --- a/dev/check-license +++ b/dev/check-license @@ -27,23 +27,23 @@ acquire_rat_jar () { if [[ ! -f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet - if [ ! -f ${JAR} ]; then + if [ ! -f "$JAR" ]; then # Download printf "Attempting to fetch rat\n" JAR_DL=${JAR}.part if hash curl 2>/dev/null; then - (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (curl --progress-bar ${URL1} > "$JAR_DL" || curl --progress-bar ${URL2} > "$JAR_DL") && mv "$JAR_DL" "$JAR" elif hash wget 2>/dev/null; then - (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (wget --progress=bar ${URL1} -O "$JAR_DL" || wget --progress=bar ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR" else printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 fi fi - if [ ! -f ${JAR} ]; then - # We failed to download - printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" - exit -1 + if [ ! -f "$JAR" ]; then + # We failed to download + printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + exit -1 fi printf "Launching rat from ${JAR}\n" fi @@ -51,7 +51,7 @@ acquire_rat_jar () { # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" if test -x "$JAVA_HOME/bin/java"; then declare java_cmd="$JAVA_HOME/bin/java" diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 38830103d1e8d..af46572e6602b 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -53,7 +53,7 @@ if [[ ! "$@" =~ --package-only ]]; then -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ --batch-mode release:prepare @@ -61,7 +61,7 @@ if [[ ! 
"$@" =~ --package-only ]]; then -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ release:perform cd .. @@ -111,10 +111,12 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" make_binary_release "hadoop2" \ - "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" +make_binary_release "hadoop2-without-hive" \ + "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" # Copy data echo "Copying release tarballs" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index c44320239bbbf..53df9b5a3f1d5 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -29,7 +29,6 @@ import re import subprocess import sys -import tempfile import urllib2 try: @@ -39,15 +38,15 @@ JIRA_IMPORTED = False # Location of your Spark git development area -SPARK_HOME = os.environ.get("SPARK_HOME", "/home/patrick/Documents/spark") +SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site PR_REMOTE_NAME = os.environ.get("PR_REMOTE_NAME", "apache-github") # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") # ASF JIRA username -JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password -JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "1234") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" @@ -129,7 +128,7 @@ def merge_pr(pr_num, target_ref): merge_message_flags = [] merge_message_flags += ["-m", title] - if body != None: + if body is not None: # We remove @ symbols from the body to avoid triggering e-mails # to people every time someone creates a public fork of Spark. merge_message_flags += ["-m", body.replace("@", "")] @@ -179,7 +178,14 @@ def cherry_pick(pr_num, merge_hash, default_branch): run_cmd("git fetch %s %s:%s" % (PUSH_REMOTE_NAME, pick_ref, pick_branch_name)) run_cmd("git checkout %s" % pick_branch_name) - run_cmd("git cherry-pick -sx %s" % merge_hash) + + try: + run_cmd("git cherry-pick -sx %s" % merge_hash) + except Exception as e: + msg = "Error cherry-picking: %s\nWould you like to manually fix-up this merge?" % e + continue_maybe(msg) + msg = "Okay, please fix any conflicts and finish the cherry-pick. Finished?" + continue_maybe(msg) continue_maybe("Pick complete (local ref %s). Push to %s?" % ( pick_branch_name, PUSH_REMOTE_NAME)) @@ -280,6 +286,7 @@ def get_version_json(version_str): pr_num = raw_input("Which pull request would you like to merge? (e.g. 
34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) +pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) url = pr["url"] title = pr["title"] @@ -289,19 +296,23 @@ def get_version_json(version_str): base_ref = pr["head"]["ref"] pr_repo_desc = "%s/%s" % (user_login, base_ref) -if pr["merged"] is True: +# Merged pull requests don't appear as merged in the GitHub API; +# Instead, they're closed by asfgit. +merge_commits = \ + [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] + +if merge_commits: + merge_hash = merge_commits[0]["commit_id"] + message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] + print "Pull request %s has already been merged, assuming you want to backport" % pr_num - merge_commit_desc = run_cmd([ - 'git', 'log', '--merges', '--first-parent', - '--grep=pull request #%s' % pr_num, '--oneline']).split("\n")[0] - if merge_commit_desc == "": + commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', + "%s^{commit}" % merge_hash]).strip() != "" + if not commit_is_downloaded: fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) - merge_hash = merge_commit_desc[:7] - message = merge_commit_desc[8:] - - print "Found: %s" % message - maybe_cherry_pick(pr_num, merge_hash, latest_branch) + print "Found commit %s:\n%s" % (merge_hash, message) + cherry_pick(pr_num, merge_hash, latest_branch) sys.exit(0) if not bool(pr["mergeable"]): @@ -323,9 +334,13 @@ def get_version_json(version_str): merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: - continue_maybe("Would you like to update an associated JIRA?") - jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) - resolve_jira(title, merged_refs, jira_comment) + if JIRA_USERNAME and JIRA_PASSWORD: + continue_maybe("Would you like to update an associated JIRA?") + jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) + resolve_jira(title, merged_refs, jira_comment) + else: + print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Exiting without trying to close the associated JIRA." else: print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." print "Exiting without trying to close the associated JIRA." 
diff --git a/dev/mima b/dev/mima index 7857294f61caf..4c3e65039b160 100755 --- a/dev/mima +++ b/dev/mima @@ -22,7 +22,7 @@ set -e # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update diff --git a/dev/run-tests b/dev/run-tests index 51e4def0f835a..daa85bc750c07 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -19,9 +19,24 @@ # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" -export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" +if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then + if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=1.0.4" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" + fi +fi + +if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" +fi +echo "SBT_MAVEN_PROFILES_ARGS=\"$SBT_MAVEN_PROFILES_ARGS\"" # Remove work directory rm -rf ./work @@ -37,7 +52,7 @@ JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..* # Partial solution for SPARK-1455. Only run Hive tests if there are sql changes. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master - diffs=`git diff --dirstat master | awk '{ print $2; }' | grep "^sql/"` + diffs=`git diff --name-only master | grep "^sql/"` if [ -n "$diffs" ]; then echo "Detected changes in SQL. Will run Hive test suite." export _RUN_SQL_TESTS=true # exported for PySpark tests @@ -61,16 +76,15 @@ dev/scalastyle echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" + +if [ -n "$_RUN_SQL_TESTS" ]; then + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" +fi # echo "q" is needed because sbt on encountering a build file with failure # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. 
-if [ -n "$_RUN_SQL_TESTS" ]; then - echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \ - assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" -else - echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \ - grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" -fi +echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \ + grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" echo "=========================================================================" echo "Running PySpark tests" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 8dda671e976ce..3076eb847b420 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,7 +22,7 @@ # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" diff --git a/dev/scalastyle b/dev/scalastyle index a02d06912f238..d9f2b91a3a091 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt diff --git a/docs/configuration.md b/docs/configuration.md index 2e6c85cc2bcca..2a71d7b820e5f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -414,10 +414,18 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.mb 2 - Maximum object size to allow within Kryo (the library needs to create a buffer at least as - large as the largest single object you'll serialize). Increase this if you get a "buffer limit - exceeded" exception inside Kryo. Note that there will be one buffer per core on each - worker. + Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer + per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max.mb if needed. + + + + spark.kryoserializer.buffer.max.mb + 64 + + Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any + object you attempt to serialize. Increase this if you get a "buffer limit exceeded" exception + inside Kryo. @@ -533,6 +541,13 @@ Apart from these, the following properties are also available, and may be useful output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + + spark.executor.heartbeatInterval + 10000 + Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let + the driver know that the executor is still alive and update it with metrics for in-progress + tasks. + #### Networking diff --git a/docs/monitoring.md b/docs/monitoring.md index 84073fe4d949a..d07ec4a57a2cc 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -33,7 +33,7 @@ application's UI after the application has finished. If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of a finished application through Spark's history server, provided that the application's event logs exist. 
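[Editor's note] The configuration.md hunk above splits the old Kryo buffer setting into an initial size (`spark.kryoserializer.buffer.mb`) and a hard ceiling (`spark.kryoserializer.buffer.max.mb`), and introduces `spark.executor.heartbeatInterval`. A hedged sketch of setting these from application code; the property names come from the documentation hunk, the values are purely illustrative:

```scala
import org.apache.spark.SparkConf

// Sketch only: property names are taken from the documentation hunk above;
// the chosen values are illustrative, not recommendations.
val conf = new SparkConf()
  .setAppName("KryoBufferTuning")
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryoserializer.buffer.mb", "2")        // initial buffer size, one buffer per core
  .set("spark.kryoserializer.buffer.max.mb", "128")  // must exceed the largest object you serialize
  .set("spark.executor.heartbeatInterval", "10000")  // ms between executor heartbeats to the driver
```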
-You can start a the history server by executing: +You can start the history server by executing: ./sbin/start-history-server.sh @@ -106,7 +106,7 @@ follows: Indicates whether the history server should use kerberos to login. This is useful if the history server is accessing HDFS files on a secure Hadoop cluster. If this is - true it looks uses the configs spark.history.kerberos.principal and + true, it uses the configs spark.history.kerberos.principal and spark.history.kerberos.keytab. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 90c69713019f2..a88bf27add883 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -383,16 +383,16 @@ Apart from text files, Spark's Python API also supports several other data forma * `RDD.saveAsPickleFile` and `SparkContext.pickleFile` support saving an RDD in a simple format consisting of pickled Python objects. Batching is used on pickle serialization, with default batch size 10. -* Details on reading `SequenceFile` and arbitrary Hadoop `InputFormat` are given below. - -### SequenceFile and Hadoop InputFormats +* SequenceFile and Hadoop Input/Output Formats **Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on SparkSQL, in which case SparkSQL is the preferred approach. -#### Writable Support +**Writable Support** -PySpark SequenceFile support loads an RDD within Java, and pickles the resulting Java objects using -[Pyrolite](https://github.com/irmen/Pyrolite/). The following Writables are automatically converted: +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +Writables are automatically converted: @@ -403,32 +403,30 @@ PySpark SequenceFile support loads an RDD within Java, and pickles the resulting - - -
Writable Type                                      | Python Type
BooleanWritable                                    | bool
BytesWritable                                      | bytearray
NullWritable                                       | None
ArrayWritable                                      | list of primitives, or tuple of objects
MapWritable                                        | dict
Custom Class conforming to Java Bean conventions   | dict of public properties (via JavaBean getters and setters) + __class__ for the class type
-#### Loading SequenceFiles +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Python `array.array` for arrays of primitive types, users need to specify custom converters. + +**Saving and Loading SequenceFiles** -Similarly to text files, SequenceFiles can be loaded by specifying the path. The key and value +Similarly to text files, SequenceFiles can be saved and loaded by specifying the path. The key and value classes can be specified, but for standard Writables this is not required. {% highlight python %} ->>> rdd = sc.sequenceFile("path/to/sequencefile/of/doubles") ->>> rdd.collect() # this example has DoubleWritable keys and Text values -[(1.0, u'aa'), - (2.0, u'bb'), - (2.0, u'aa'), - (3.0, u'cc'), - (2.0, u'bb'), - (1.0, u'aa')] +>>> rdd = sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) +>>> rdd.saveAsSequenceFile("path/to/file") +>>> sorted(sc.sequenceFile("path/to/file").collect()) +[(1, u'a'), (2, u'aa'), (3, u'aaa')] {% endhighlight %} -#### Loading Other Hadoop InputFormats +**Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat, for both 'new' and 'old' Hadoop APIs. If required, -a Hadoop configuration can be passed in as a Python dict. Here is an example using the +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: {% highlight python %} @@ -447,8 +445,7 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase) or custom -classes that don't conform to the JavaBean requirements, then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided for this. Simply extend this trait and implement your transformation code in the ```convert``` @@ -456,11 +453,8 @@ method. Remember to ensure that this class, along with any dependencies required classpath. See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/pythonconverters) -for examples of using HBase and Cassandra ```InputFormat```. - -Future support for writing data out as ```SequenceFileOutputFormat``` and other ```OutputFormats```, -is forthcoming. +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
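[Editor's note] The programming-guide hunk above tells users with custom binary data to extend the `Converter` trait (`org.apache.spark.api.python.Converter`) and implement its `convert` method, and says arrays of primitives need custom `ArrayWritable` subtypes. A minimal, hypothetical sketch of such a converter; the `DoubleArrayWritable` class and the exact type parameters of `Converter` are assumptions based on that prose, not code from this patch:

```scala
import org.apache.hadoop.io.{ArrayWritable, DoubleWritable}
import org.apache.spark.api.python.Converter

// Hypothetical custom ArrayWritable subtype, as the guide says users must provide one
// for arrays of primitive types.
class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable])

// Sketch of a Converter that unwraps the Writable into a plain Array[Double] that
// Pyrolite can pickle for the Python side.
class DoubleArrayToArrayConverter extends Converter[Any, Array[Double]] {
  override def convert(obj: Any): Array[Double] = obj match {
    case arr: DoubleArrayWritable =>
      arr.get().map { case d: DoubleWritable => d.get() }
    case other =>
      throw new IllegalArgumentException(s"Unexpected type ${other.getClass}")
  }
}
```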
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index ad8b6c0e51a78..2fb30765f35e8 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -242,9 +242,6 @@ To run an interactive Spark shell against the cluster, run the following command ./bin/spark-shell --master spark://IP:PORT -Note that if you are running spark-shell from one of the spark cluster machines, the `bin/spark-shell` script will -automatically set MASTER from the `SPARK_MASTER_IP` and `SPARK_MASTER_PORT` variables in `conf/spark-env.sh`. - You can also pass an option `--cores <numCores>` to control the number of cores that spark-shell uses on the cluster. # Launching Compiled Spark Applications diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 38728534a46e0..7261badd411a9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -136,7 +136,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext.createSchemaRDD // Define the schema using a case class. -// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, +// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, // you can use custom classes that implement the Product interface. case class Person(name: String, age: Int) @@ -487,9 +487,9 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can also experiment with the `LocalHiveContext`, -which is similar to `HiveContext`, but creates a local copy of the `metastore` and `warehouse` -automatically. +not have an existing Hive deployment can still create a HiveContext. When not configured by the +hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current +directory. {% highlight scala %} // sc is an existing SparkContext. @@ -548,7 +548,6 @@ results = hiveContext.hql("FROM src SELECT key, value").collect() - # Writing Language-Integrated Relational Queries **Language-Integrated queries are currently only supported in Scala.** @@ -573,4 +572,210 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres evaluated by the SQL execution engine. A full list of the functions supported can be found in the [ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). - \ No newline at end of file + + +## Running the Thrift JDBC server + +The Thrift JDBC server implemented here corresponds to the [`HiveServer2`] +(https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test +the JDBC server with the beeline script that comes with either Spark or Hive 0.12. In order to use Hive +you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use `-Phive-thriftserver` +for Maven). + +To start the JDBC server, run the following in the Spark directory: + + ./sbin/start-thriftserver.sh + +The default port the server listens on is 10000. To listen on a customized host and port, please set +the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may +run `./sbin/start-thriftserver.sh --help` for a complete list of all available options.
Now you can +use beeline to test the Thrift JDBC server: + + ./bin/beeline + +Connect to the JDBC server in beeline with: + + beeline> !connect jdbc:hive2://localhost:10000 + +Beeline will ask you for a username and password. In non-secure mode, simply enter the username on +your machine and a blank password. For secure mode, please follow the instructions given in the +[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. + +You may also use the beeline script that comes with Hive. + +### Migration Guide for Shark Users + +#### Reducer number + +In Shark, the default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark +SQL deprecates this property in favor of a new property, `spark.sql.shuffle.partitions`, whose default value +is 200. Users may customize this property via `SET`: + +``` +SET spark.sql.shuffle.partitions=10; +SELECT page, count(*) c FROM logs_last_month_cached +GROUP BY page ORDER BY c DESC LIMIT 10; +``` + +You may also put this property in `hive-site.xml` to override the default value. + +For now, the `mapred.reduce.tasks` property is still recognized, and is converted to +`spark.sql.shuffle.partitions` automatically. + +#### Caching + +The `shark.cache` table property no longer exists, and tables whose names end with `_cached` are no +longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to +let users control table caching explicitly: + +``` +CACHE TABLE logs_last_month; +UNCACHE TABLE logs_last_month; +``` + +**NOTE** `CACHE TABLE tbl` is lazy: it only marks table `tbl` as "needs to be cached if necessary", +but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be +cached, you may simply count the table immediately after executing `CACHE TABLE`: + +``` +CACHE TABLE logs_last_month; +SELECT COUNT(1) FROM logs_last_month; +``` + +Several caching-related features are not supported yet: + +* User defined partition level cache eviction policy +* RDD reloading +* In-memory cache write through policy + +### Compatibility with Apache Hive + +#### Deploying in Existing Hive Warehouses + +The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive +installations. You do not need to modify your existing Hive Metastore or change the data placement +or partitioning of your tables.
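As a complement to the beeline walkthrough above, here is a hedged sketch (not part of this patch) of connecting to the Thrift JDBC server programmatically over plain JDBC; it assumes the Hive 0.12 JDBC driver `org.apache.hive.jdbc.HiveDriver` and its dependencies are on the client classpath, and uses the same URL and non-secure credentials described above:

```scala
import java.sql.DriverManager

object ThriftServerClientSketch {
  def main(args: Array[String]): Unit = {
    // Load the HiveServer2 JDBC driver (shipped with Hive 0.12).
    Class.forName("org.apache.hive.jdbc.HiveDriver")

    // Same URL as the beeline example; in non-secure mode the username is the
    // OS user and the password is left blank.
    val conn = DriverManager.getConnection(
      "jdbc:hive2://localhost:10000", System.getProperty("user.name"), "")
    try {
      val stmt = conn.createStatement()
      val rs = stmt.executeQuery("SHOW TABLES")
      while (rs.next()) {
        println(rs.getString(1))
      }
    } finally {
      conn.close()
    }
  }
}
```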
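Relatedly, the lazy `CACHE TABLE` semantics described in the caching subsection above can also be exercised from a Spark program. A minimal sketch, assuming a table named `logs_last_month` already exists in the metastore:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object CacheTableSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("CacheTableSketch"))
    val hiveContext = new HiveContext(sc)

    // Marks the table as cacheable; nothing is materialized yet because
    // CACHE TABLE is lazy.
    hiveContext.hql("CACHE TABLE logs_last_month")

    // Touching the table forces the in-memory columnar cache to be built.
    val rowCount = hiveContext.hql("SELECT COUNT(1) FROM logs_last_month")
      .collect().head.getLong(0)
    println(s"Cached $rowCount rows")

    sc.stop()
  }
}
```

The same effect is available programmatically through `cacheTable("logs_last_month")`, as noted in the cached-tables section later in this guide.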
+ +#### Supported Hive Features + +Spark SQL supports the vast majority of Hive features, such as: + +* Hive query statements, including: + * `SELECT` + * `GROUP BY` + * `ORDER BY` + * `CLUSTER BY` + * `SORT BY` +* All Hive operators, including: + * Relational operators (`=`, `<=>`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc) + * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc) + * Logical operators (`AND`, `&&`, `OR`, `||`, etc) + * Complex type constructors + * Mathematical functions (`sign`, `ln`, `cos`, etc) + * String functions (`instr`, `length`, `printf`, etc) +* User defined functions (UDF) +* User defined aggregation functions (UDAF) +* User defined serialization formats (SerDe's) +* Joins + * `JOIN` + * `{LEFT|RIGHT|FULL} OUTER JOIN` + * `LEFT SEMI JOIN` + * `CROSS JOIN` +* Unions +* Subqueries + * `SELECT col FROM ( SELECT a + b AS col from t1) t2` +* Sampling +* Explain +* Partitioned tables +* All Hive DDL Functions, including: + * `CREATE TABLE` + * `CREATE TABLE AS SELECT` + * `ALTER TABLE` +* Most Hive Data types, including: + * `TINYINT` + * `SMALLINT` + * `INT` + * `BIGINT` + * `BOOLEAN` + * `FLOAT` + * `DOUBLE` + * `STRING` + * `BINARY` + * `TIMESTAMP` + * `ARRAY<>` + * `MAP<>` + * `STRUCT<>` + +#### Unsupported Hive Functionality + +Below is a list of Hive features that we don't support yet. Most of these features are rarely used +in Hive deployments. + +**Major Hive Features** + +* Tables with buckets: a bucket is the hash partitioning within a Hive table partition. Spark SQL + doesn't support buckets yet. + +**Esoteric Hive Features** + +* Tables with partitions using different input formats: In Spark SQL, all table partitions need to + have the same input format. +* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions + (e.g. condition "`key < 10`"), Spark SQL will output the wrong result for the `NULL` tuple. +* `UNIONTYPE` +* Unique join +* Single query multi insert +* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at + the moment. + +**Hive Input/Output Formats** + +* File format for CLI: For results shown back to the CLI, Spark SQL only supports TextOutputFormat. +* Hadoop archive + +**Hive Optimizations** + +A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are +not necessary due to Spark SQL's in-memory computational model. Others are slotted for future +releases of Spark SQL. + +* Block level bitmap indexes and virtual columns (used to build indexes) +* Automatically convert a join to map join: For joining a large table with multiple small tables, + Hive automatically converts the join into a map join. We are adding this auto conversion in the + next release. +* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you + need to control the degree of parallelism post-shuffle using "SET + spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the + next release. +* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still + launches tasks to compute the result. +* Skew data flag: Spark SQL does not follow the skew data flags in Hive. +* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
+* Merge multiple small files for query results: if the result output contains multiple small files, + Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS + metadata. Spark SQL does not support that. + +## Running the Spark SQL CLI + +The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute +queries input from the command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. + +To start the Spark SQL CLI, run the following in the Spark directory: + + ./bin/spark-sql + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +You may run `./bin/spark-sql --help` for a complete list of all available +options. + +# Cached tables + +Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. +Then Spark SQL will scan only required columns and will automatically tune compression to minimize +memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. + +Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in the +in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to +cache tables. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 90a0eef60c200..7b8b7933434c4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -939,7 +939,7 @@ Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input stream receiving two topics of data can be split into two Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to received in parallel, and increasing overall throughput. +thus allowing data to be received in parallel, and increasing overall throughput. Another parameter that should be considered is the receiver's blocking interval. For most receivers, the received data is coalesced together into large blocks of data before storing inside Spark's memory. @@ -980,7 +980,7 @@ If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduced the task +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to diff --git a/examples/pom.xml b/examples/pom.xml index bd1c387c2eb91..c4ed0f5a6a02b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-examples_2.10 - examples + examples jar Spark Project Examples diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py new file mode 100644 index 0000000000000..1dfbf98604425 --- /dev/null +++ b/examples/src/main/python/cassandra_outputformat.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Create data in Cassandra fist +(following: https://wiki.apache.org/cassandra/GettingStarted) + +cqlsh> CREATE KEYSPACE test + ... WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; +cqlsh> use test; +cqlsh:test> CREATE TABLE users ( + ... user_id int PRIMARY KEY, + ... fname text, + ... lname text + ... ); + +> cassandra_outputformat test users 1745 john smith +> cassandra_outputformat test users 1744 john doe +> cassandra_outputformat test users 1746 john smith + +cqlsh:test> SELECT * FROM users; + + user_id | fname | lname +---------+-------+------- + 1745 | john | smith + 1744 | john | doe + 1746 | john | smith +""" +if __name__ == "__main__": + if len(sys.argv) != 7: + print >> sys.stderr, """ + Usage: cassandra_outputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_outputformat.py + Assumes you have created the following table in Cassandra already, + running on , in . + + cqlsh:> CREATE TABLE ( + ... user_id int PRIMARY KEY, + ... fname text, + ... lname text + ... ); + """ + exit(-1) + + host = sys.argv[1] + keyspace = sys.argv[2] + cf = sys.argv[3] + sc = SparkContext(appName="CassandraOutputFormat") + + conf = {"cassandra.output.thrift.address":host, + "cassandra.output.thrift.port":"9160", + "cassandra.output.keyspace":keyspace, + "cassandra.output.partitioner.class":"Murmur3Partitioner", + "cassandra.output.cql":"UPDATE " + keyspace + "." 
+ cf + " SET fname = ?, lname = ?", + "mapreduce.output.basename":cf, + "mapreduce.outputformat.class":"org.apache.cassandra.hadoop.cql3.CqlOutputFormat", + "mapreduce.job.output.key.class":"java.util.Map", + "mapreduce.job.output.value.class":"java.util.List"} + key = {"user_id" : int(sys.argv[4])} + sc.parallelize([(key, sys.argv[5:])]).saveAsNewAPIHadoopDataset( + conf=conf, + keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", + valueConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLValueConverter") diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 3289d9880a0f5..c9fa8e171c2a1 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -65,7 +65,8 @@ "org.apache.hadoop.hbase.mapreduce.TableInputFormat", "org.apache.hadoop.hbase.io.ImmutableBytesWritable", "org.apache.hadoop.hbase.client.Result", - valueConverter="org.apache.spark.examples.pythonconverters.HBaseConverter", + keyConverter="org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter", + valueConverter="org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter", conf=conf) output = hbase_rdd.collect() for (k, v) in output: diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py new file mode 100644 index 0000000000000..5e11548fd13f7 --- /dev/null +++ b/examples/src/main/python/hbase_outputformat.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Create test table in HBase first: + +hbase(main):001:0> create 'test', 'f1' +0 row(s) in 0.7840 seconds + +> hbase_outputformat test row1 f1 q1 value1 +> hbase_outputformat test row2 f1 q1 value2 +> hbase_outputformat test row3 f1 q1 value3 +> hbase_outputformat test row4 f1 q1 value4 + +hbase(main):002:0> scan 'test' +ROW COLUMN+CELL + row1 column=f1:q1, timestamp=1405659615726, value=value1 + row2 column=f1:q1, timestamp=1405659626803, value=value2 + row3 column=f1:q1, timestamp=1405659640106, value=value3 + row4 column=f1:q1, timestamp=1405659650292, value=value4 +4 row(s) in 0.0780 seconds +""" +if __name__ == "__main__": + if len(sys.argv) != 7: + print >> sys.stderr, """ + Usage: hbase_outputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_outputformat.py + Assumes you have created
with column family in HBase running on already + """ + exit(-1) + + host = sys.argv[1] + table = sys.argv[2] + sc = SparkContext(appName="HBaseOutputFormat") + + conf = {"hbase.zookeeper.quorum": host, + "hbase.mapred.outputtable": table, + "mapreduce.outputformat.class" : "org.apache.hadoop.hbase.mapreduce.TableOutputFormat", + "mapreduce.job.output.key.class" : "org.apache.hadoop.hbase.io.ImmutableBytesWritable", + "mapreduce.job.output.value.class" : "org.apache.hadoop.io.Writable"} + + sc.parallelize([sys.argv[3:]]).map(lambda x: (x[0], x)).saveAsNewAPIHadoopDataset( + conf=conf, + keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", + valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 43f13fe24f0d0..cf3d2cca81ff6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,7 +21,6 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} @@ -33,9 +32,12 @@ import org.apache.spark.rdd.RDD /** * An example runner for decision tree. Run with * {{{ - * ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] + * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. 
*/ object DecisionTreeRunner { @@ -48,11 +50,12 @@ object DecisionTreeRunner { case class Params( input: String = null, + dataFormat: String = "libsvm", algo: Algo = Classification, - numClassesForClassification: Int = 2, - maxDepth: Int = 5, + maxDepth: Int = 4, impurity: ImpurityType = Gini, - maxBins: Int = 100) + maxBins: Int = 100, + fracTest: Double = 0.2) def main(args: Array[String]) { val defaultParams = Params() @@ -69,25 +72,31 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("numClassesForClassification") - .text(s"number of classes for classification, " - + s"default: ${defaultParams.numClassesForClassification}") - .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) arg[String]("") .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)") .required() .action((x, c) => c.copy(input = x)) checkConfig { params => - if (params.algo == Classification && - (params.impurity == Gini || params.impurity == Entropy)) { - success - } else if (params.algo == Regression && params.impurity == Variance) { - success + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") } else { - failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + if (params.algo == Classification && + (params.impurity == Gini || params.impurity == Entropy)) { + success + } else if (params.algo == Regression && params.impurity == Variance) { + success + } else { + failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + } } } } @@ -100,16 +109,57 @@ object DecisionTreeRunner { } def run(params: Params) { + val conf = new SparkConf().setAppName("DecisionTreeRunner") val sc = new SparkContext(conf) // Load training data and cache it. - val examples = MLUtils.loadLabeledPoints(sc, params.input).cache() + val origExamples = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + } + // For classification, re-index classes if needed. 
+ val (examples, numClasses) = params.algo match { + case Classification => { + // classCounts: class --> # examples in class + val classCounts = origExamples.map(_.label).countByValue() + val sortedClasses = classCounts.keys.toList.sorted + val numClasses = classCounts.size + // classIndexMap: class --> index in 0,...,numClasses-1 + val classIndexMap = { + if (classCounts.keySet != Set(0.0, 1.0)) { + sortedClasses.zipWithIndex.toMap + } else { + Map[Double, Int]() + } + } + val examples = { + if (classIndexMap.isEmpty) { + origExamples + } else { + origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features)) + } + } + val numExamples = examples.count() + println(s"numClasses = $numClasses.") + println(s"Per-class example fractions, counts:") + println(s"Class\tFrac\tCount") + sortedClasses.foreach { c => + val frac = classCounts(c) / numExamples.toDouble + println(s"$c\t$frac\t${classCounts(c)}") + } + (examples, numClasses) + } + case Regression => + (origExamples, 0) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } - val splits = examples.randomSplit(Array(0.8, 0.2)) + // Split into training, test. + val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) val training = splits(0).cache() val test = splits(1).cache() - val numTraining = training.count() val numTest = test.count() @@ -129,17 +179,19 @@ object DecisionTreeRunner { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = params.numClassesForClassification) + numClassesForClassification = numClasses) val model = DecisionTree.train(training, strategy) + println(model) + if (params.algo == Classification) { val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy.") + println(s"Test accuracy = $accuracy") } if (params.algo == Regression) { val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse.") + println(s"Test mean squared error = $mse") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 4811bb70e4b28..05b7d66f8dffd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -91,7 +91,7 @@ object LinearRegression extends App { Logger.getRootLogger.setLevel(Level.WARN) - val examples = MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache() + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() val splits = examples.randomSplit(Array(0.8, 0.2)) val training = splits(0).cache() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 537e68a0991aa..88acd9dbb0878 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -22,7 +22,7 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.classification.NaiveBayes -import org.apache.spark.mllib.util.{MLUtils, MulticlassLabelParser} +import org.apache.spark.mllib.util.MLUtils /** * An example naive Bayes app. 
Run with @@ -76,7 +76,7 @@ object SparseNaiveBayes { if (params.minPartitions > 0) params.minPartitions else sc.defaultMinPartitions val examples = - MLUtils.loadLibSVMFile(sc, params.input, multiclass = true, params.numFeatures, minPartitions) + MLUtils.loadLibSVMFile(sc, params.input, params.numFeatures, minPartitions) // Cache examples because it will be used in both training and evaluation. examples.cache() diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 29a65c7a5f295..83feb5703b908 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples.pythonconverters import org.apache.spark.api.python.Converter import java.nio.ByteBuffer import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions.{mapAsJavaMap, mapAsScalaMap} +import collection.JavaConversions._ /** @@ -44,3 +44,25 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) } } + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * Map[String, Int] to Cassandra key + */ +class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { + override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { + val input = obj.asInstanceOf[java.util.Map[String, Int]] + mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * List[String] to Cassandra value + */ +class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { + override def convert(obj: Any): java.util.List[ByteBuffer] = { + val input = obj.asInstanceOf[java.util.List[String]] + seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala new file mode 100644 index 0000000000000..273bee0a8b30f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.pythonconverters + +import scala.collection.JavaConversions._ + +import org.apache.spark.api.python.Converter +import org.apache.hadoop.hbase.client.{Put, Result} +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.hadoop.hbase.util.Bytes + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts an + * HBase Result to a String + */ +class HBaseResultToStringConverter extends Converter[Any, String] { + override def convert(obj: Any): String = { + val result = obj.asInstanceOf[Result] + Bytes.toStringBinary(result.value()) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts an + * ImmutableBytesWritable to a String + */ +class ImmutableBytesWritableToStringConverter extends Converter[Any, String] { + override def convert(obj: Any): String = { + val key = obj.asInstanceOf[ImmutableBytesWritable] + Bytes.toStringBinary(key.get()) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * String to an ImmutableBytesWritable + */ +class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBytesWritable] { + override def convert(obj: Any): ImmutableBytesWritable = { + val bytes = Bytes.toBytes(obj.asInstanceOf[String]) + new ImmutableBytesWritable(bytes) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * list of Strings to HBase Put + */ +class StringListToPutConverter extends Converter[Any, Put] { + override def convert(obj: Any): Put = { + val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val put = new Put(output(0)) + put.add(output(1), output(2), output(3)) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 66a23fac39999..dc5290fb4f10e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.sql.hive import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ -import org.apache.spark.sql.hive.LocalHiveContext +import org.apache.spark.sql.hive.HiveContext object HiveFromSpark { case class Record(key: Int, value: String) @@ -31,7 +31,7 @@ object HiveFromSpark { // A local hive context creates an instance of the Hive Metastore in process, storing the // the warehouse data in the current directory. This location can be overridden by // specifying a second parameter to the constructor. - val hiveContext = new LocalHiveContext(sc) + val hiveContext = new HiveContext(sc) import hiveContext._ hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala new file mode 100644 index 0000000000000..1cc8c8d5c23b6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.flume._ +import org.apache.spark.util.IntParam +import java.net.InetSocketAddress + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with the Spark Sink running in a Flume agent. See + * the Spark Streaming programming guide for more details. + * + * Usage: FlumePollingEventCount + * `host` is the host on which the Spark Sink is running. + * `port` is the port at which the Spark Sink is listening. + * + * To run this example: + * `$ bin/run-example org.apache.spark.examples.streaming.FlumePollingEventCount [host] [port] ` + */ +object FlumePollingEventCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println( + "Usage: FlumePollingEventCount ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(host, IntParam(port)) = args + + val batchInterval = Milliseconds(2000) + + // Create the context and set the batch size + val sparkConf = new SparkConf().setAppName("FlumePollingEventCount") + val ssc = new StreamingContext(sparkConf, batchInterval) + + // Create a flume stream that polls the Spark Sink running in a Flume agent + val stream = FlumeUtils.createPollingStream(ssc, host, port) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events." 
).print() + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml new file mode 100644 index 0000000000000..d11129ce8d89d --- /dev/null +++ b/external/flume-sink/pom.xml @@ -0,0 +1,100 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + spark-streaming-flume-sink_2.10 + + streaming-flume-sink + + + jar + Spark Project External Flume Sink + http://spark.apache.org/ + + + org.apache.flume + flume-ng-sdk + 1.4.0 + + + io.netty + netty + + + org.apache.thrift + libthrift + + + + + org.apache.flume + flume-ng-core + 1.4.0 + + + io.netty + netty + + + org.apache.thrift + libthrift + + + + + org.scala-lang + scala-library + 2.10.4 + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.avro + avro-maven-plugin + 1.7.3 + + + ${project.basedir}/target/scala-${scala.binary.version}/src_managed/main/compiled_avro + + + + generate-sources + + idl-protocol + + + + + + + diff --git a/external/flume-sink/src/main/avro/sparkflume.avdl b/external/flume-sink/src/main/avro/sparkflume.avdl new file mode 100644 index 0000000000000..8806e863ac7c6 --- /dev/null +++ b/external/flume-sink/src/main/avro/sparkflume.avdl @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +@namespace("org.apache.spark.streaming.flume.sink") + +protocol SparkFlumeProtocol { + + record SparkSinkEvent { + map headers; + bytes body; + } + + record EventBatch { + string errorMsg = ""; // If this is empty it is a valid message, else it represents an error + string sequenceNumber; + array events; + } + + EventBatch getEventBatch (int n); + + void ack (string sequenceNumber); + + void nack (string sequenceNumber); +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala new file mode 100644 index 0000000000000..17cbc6707b5ea --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import org.slf4j.{Logger, LoggerFactory} + +/** + * Copy of the org.apache.spark.Logging for being used in the Spark Sink. + * The org.apache.spark.Logging is not used so that all of Spark is not brought + * in as a dependency. + */ +private[sink] trait Logging { + // Make the log field transient so that objects with Logging can + // be serialized and used on another machine + @transient private var log_ : Logger = null + + // Method to get or create the logger for this object + protected def log: Logger = { + if (log_ == null) { + initializeIfNecessary() + var className = this.getClass.getName + // Ignore trailing $'s in the class names for Scala objects + if (className.endsWith("$")) { + className = className.substring(0, className.length - 1) + } + log_ = LoggerFactory.getLogger(className) + } + log_ + } + + // Log methods that take only a String + protected def logInfo(msg: => String) { + if (log.isInfoEnabled) log.info(msg) + } + + protected def logDebug(msg: => String) { + if (log.isDebugEnabled) log.debug(msg) + } + + protected def logTrace(msg: => String) { + if (log.isTraceEnabled) log.trace(msg) + } + + protected def logWarning(msg: => String) { + if (log.isWarnEnabled) log.warn(msg) + } + + protected def logError(msg: => String) { + if (log.isErrorEnabled) log.error(msg) + } + + // Log methods that take Throwables (Exceptions/Errors) too + protected def logInfo(msg: => String, throwable: Throwable) { + if (log.isInfoEnabled) log.info(msg, throwable) + } + + protected def logDebug(msg: => String, throwable: Throwable) { + if (log.isDebugEnabled) log.debug(msg, throwable) + } + + protected def logTrace(msg: => String, throwable: Throwable) { + if (log.isTraceEnabled) log.trace(msg, throwable) + } + + protected def logWarning(msg: => String, throwable: Throwable) { + if (log.isWarnEnabled) log.warn(msg, throwable) + } + + protected def logError(msg: => String, throwable: Throwable) { + if (log.isErrorEnabled) log.error(msg, throwable) + } + + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + + private def initializeIfNecessary() { + if (!Logging.initialized) { + Logging.initLock.synchronized { + if (!Logging.initialized) { + initializeLogging() + } + } + } + } + + private def initializeLogging() { + Logging.initialized = true + + // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html + log + } +} + +private[sink] object Logging { + @volatile private var initialized = false + val initLock = new Object() + try { + // We use reflection here to handle the case where users remove the + // slf4j-to-jul bridge order to route their logs to JUL. 
+ val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) + val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] + if (!installed) { + bridgeClass.getMethod("install").invoke(null) + } + } catch { + case e: ClassNotFoundException => // can't log anything yet so just fail silently + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala new file mode 100644 index 0000000000000..7da8eb3e35912 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.concurrent.atomic.AtomicLong + +import org.apache.flume.Channel +import org.apache.commons.lang.RandomStringUtils +import com.google.common.util.concurrent.ThreadFactoryBuilder + +/** + * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process + * requests. Each getEvents, ack and nack call is forwarded to an instance of this class. + * @param threads Number of threads to use to process requests. + * @param channel The channel that the sink pulls events from + * @param transactionTimeout Timeout in millis after which the transaction if not acked by Spark + * is rolled back. + */ +// Flume forces transactions to be thread-local. So each transaction *must* be committed, or +// rolled back from the thread it was originally created in. So each getEvents call from Spark +// creates a TransactionProcessor which runs in a new thread, in which the transaction is created +// and events are pulled off the channel. Once the events are sent to spark, +// that thread is blocked and the TransactionProcessor is saved in a map, +// until an ACK or NACK comes back or the transaction times out (after the specified timeout). +// When the response comes or a timeout is hit, the TransactionProcessor is retrieved and then +// unblocked, at which point the transaction is committed or rolled back. 
+ +private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, + val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { + val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, + new ThreadFactoryBuilder().setDaemon(true) + .setNameFormat("Spark Sink Processor Thread - %d").build())) + private val processorMap = new ConcurrentHashMap[CharSequence, TransactionProcessor]() + // This sink will not persist sequence numbers and reuses them if it gets restarted. + // So it is possible to commit a transaction which may have been meant for the sink before the + // restart. + // Since the new txn may not have the same sequence number we must guard against accidentally + // committing a new transaction. To reduce the probability of that happening a random string is + // prepended to the sequence number. Does not change for life of sink + private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqCounter = new AtomicLong(0) + + /** + * Returns a bunch of events to Spark over Avro RPC. + * @param n Maximum number of events to return in a batch + * @return [[EventBatch]] instance that has a sequence number and an array of at most n events + */ + override def getEventBatch(n: Int): EventBatch = { + logDebug("Got getEventBatch call from Spark.") + val sequenceNumber = seqBase + seqCounter.incrementAndGet() + val processor = new TransactionProcessor(channel, sequenceNumber, + n, transactionTimeout, backOffInterval, this) + transactionExecutorOpt.foreach(executor => { + executor.submit(processor) + }) + // Wait until a batch is available - will be an error if error message is non-empty + val batch = processor.getEventBatch + if (!SparkSinkUtils.isErrorBatch(batch)) { + processorMap.put(sequenceNumber.toString, processor) + logDebug("Sending event batch with sequence number: " + sequenceNumber) + } + batch + } + + /** + * Called by Spark to indicate successful commit of a batch + * @param sequenceNumber The sequence number of the event batch that was successful + */ + override def ack(sequenceNumber: CharSequence): Void = { + logDebug("Received Ack for batch with sequence number: " + sequenceNumber) + completeTransaction(sequenceNumber, success = true) + null + } + + /** + * Called by Spark to indicate failed commit of a batch + * @param sequenceNumber The sequence number of the event batch that failed + * @return + */ + override def nack(sequenceNumber: CharSequence): Void = { + completeTransaction(sequenceNumber, success = false) + logInfo("Spark failed to commit transaction. Will reattempt events.") + null + } + + /** + * Helper method to commit or rollback a transaction. + * @param sequenceNumber The sequence number of the batch that was completed + * @param success Whether the batch was successful or not. + */ + private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { + Option(removeAndGetProcessor(sequenceNumber)).foreach(processor => { + processor.batchProcessed(success) + }) + } + + /** + * Helper method to remove the TxnProcessor for a Sequence Number. Can be used to avoid a leak. + * @param sequenceNumber + * @return The transaction processor for the corresponding batch. Note that this instance is no + * longer tracked and the caller is responsible for that txn processor. + */ + private[sink] def removeAndGetProcessor(sequenceNumber: CharSequence): TransactionProcessor = { + processorMap.remove(sequenceNumber.toString) // The toString is required! 
+ } + + /** + * Shuts down the executor used to process transactions. + */ + def shutdown() { + logInfo("Shutting down Spark Avro Callback Handler") + transactionExecutorOpt.foreach(executor => { + executor.shutdownNow() + }) + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala new file mode 100644 index 0000000000000..7b735133e3d14 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.net.InetSocketAddress +import java.util.concurrent._ + +import org.apache.avro.ipc.NettyServer +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.flume.Context +import org.apache.flume.Sink.Status +import org.apache.flume.conf.{Configurable, ConfigurationException} +import org.apache.flume.sink.AbstractSink + +/** + * A sink that uses Avro RPC to run a server that can be polled by Spark's + * FlumePollingInputDStream. This sink has the following configuration parameters: + * + * hostname - The hostname to bind to. Default: 0.0.0.0 + * port - The port to bind to. (No default - mandatory) + * timeout - Time in seconds after which a transaction is rolled back, + * if an ACK is not received from Spark within that time + * threads - Number of threads to use to receive requests from Spark (Default: 10) + * + * This sink is unlike other Flume sinks in the sense that it does not push data, + * instead the process method in this sink simply blocks the SinkRunner the first time it is + * called. This sink starts up an Avro IPC server that uses the SparkFlumeProtocol. + * + * Each time a getEventBatch call comes, creates a transaction and reads events + * from the channel. When enough events are read, the events are sent to the Spark receiver and + * the thread itself is blocked and a reference to it saved off. + * + * When the ack for that batch is received, + * the thread which created the transaction is is retrieved and it commits the transaction with the + * channel from the same thread it was originally created in (since Flume transactions are + * thread local). If a nack is received instead, the sink rolls back the transaction. If no ack + * is received within the specified timeout, the transaction is rolled back too. If an ack comes + * after that, it is simply ignored and the events get re-sent. + * + */ + +private[flume] +class SparkSink extends AbstractSink with Logging with Configurable { + + // Size of the pool to use for holding transaction processors. 
+ private var poolSize: Integer = SparkSinkConfig.DEFAULT_THREADS + + // Timeout for each transaction. If spark does not respond in this much time, + // rollback the transaction + private var transactionTimeout = SparkSinkConfig.DEFAULT_TRANSACTION_TIMEOUT + + // Address info to bind on + private var hostname: String = SparkSinkConfig.DEFAULT_HOSTNAME + private var port: Int = 0 + + private var backOffInterval: Int = 200 + + // Handle to the server + private var serverOpt: Option[NettyServer] = None + + // The handler that handles the callback from Avro + private var handler: Option[SparkAvroCallbackHandler] = None + + // Latch that blocks off the Flume framework from wasting 1 thread. + private val blockingLatch = new CountDownLatch(1) + + override def start() { + logInfo("Starting Spark Sink: " + getName + " on port: " + port + " and interface: " + + hostname + " with " + "pool size: " + poolSize + " and transaction timeout: " + + transactionTimeout + ".") + handler = Option(new SparkAvroCallbackHandler(poolSize, getChannel, transactionTimeout, + backOffInterval)) + val responder = new SpecificResponder(classOf[SparkFlumeProtocol], handler.get) + // Using the constructor that takes specific thread-pools requires bringing in netty + // dependencies which are being excluded in the build. In practice, + // Netty dependencies are already available on the JVM as Flume would have pulled them in. + serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) + serverOpt.foreach(server => { + logInfo("Starting Avro server for sink: " + getName) + server.start() + }) + super.start() + } + + override def stop() { + logInfo("Stopping Spark Sink: " + getName) + handler.foreach(callbackHandler => { + callbackHandler.shutdown() + }) + serverOpt.foreach(server => { + logInfo("Stopping Avro Server for sink: " + getName) + server.close() + server.join() + }) + blockingLatch.countDown() + super.stop() + } + + override def configure(ctx: Context) { + import SparkSinkConfig._ + hostname = ctx.getString(CONF_HOSTNAME, DEFAULT_HOSTNAME) + port = Option(ctx.getInteger(CONF_PORT)). + getOrElse(throw new ConfigurationException("The port to bind to must be specified")) + poolSize = ctx.getInteger(THREADS, DEFAULT_THREADS) + transactionTimeout = ctx.getInteger(CONF_TRANSACTION_TIMEOUT, DEFAULT_TRANSACTION_TIMEOUT) + backOffInterval = ctx.getInteger(CONF_BACKOFF_INTERVAL, DEFAULT_BACKOFF_INTERVAL) + logInfo("Configured Spark Sink with hostname: " + hostname + ", port: " + port + ", " + + "poolSize: " + poolSize + ", transactionTimeout: " + transactionTimeout + ", " + + "backoffInterval: " + backOffInterval) + } + + override def process(): Status = { + // This method is called in a loop by the Flume framework - block it until the sink is + // stopped to save CPU resources. The sink runner will interrupt this thread when the sink is + // being shut down. + logInfo("Blocking Sink Runner, sink will continue to run..") + blockingLatch.await() + Status.BACKOFF + } +} + +/** + * Configuration parameters and their defaults. 
+ */ +private[flume] +object SparkSinkConfig { + val THREADS = "threads" + val DEFAULT_THREADS = 10 + + val CONF_TRANSACTION_TIMEOUT = "timeout" + val DEFAULT_TRANSACTION_TIMEOUT = 60 + + val CONF_HOSTNAME = "hostname" + val DEFAULT_HOSTNAME = "0.0.0.0" + + val CONF_PORT = "port" + + val CONF_BACKOFF_INTERVAL = "backoffInterval" + val DEFAULT_BACKOFF_INTERVAL = 200 +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala similarity index 52% rename from mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala rename to external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala index ac85677f2f014..47c0e294d6b52 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala @@ -14,28 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.streaming.flume.sink -package org.apache.spark.mllib.util - -import org.scalatest.FunSuite - -class LabelParsersSuite extends FunSuite { - test("binary label parser") { - for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) { - assert(parser.parse("+1") === 1.0) - assert(parser.parse("1") === 1.0) - assert(parser.parse("0") === 0.0) - assert(parser.parse("-1") === 0.0) - } - } - - test("multiclass label parser") { - for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) { - assert(parser.parse("0") == 0.0) - assert(parser.parse("+1") === 1.0) - assert(parser.parse("1") === 1.0) - assert(parser.parse("2") === 2.0) - assert(parser.parse("3") === 3.0) - } +private[flume] object SparkSinkUtils { + /** + * This method determines if this batch represents an error or not. + * @param batch - The batch to check + * @return - true if the batch represents an error + */ + def isErrorBatch(batch: EventBatch): Boolean = { + !batch.getErrorMsg.toString.equals("") // If there is an error message, it is an error batch. } } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala new file mode 100644 index 0000000000000..b9e3c786ebb3b --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.flume.sink + +import java.nio.ByteBuffer +import java.util +import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} + +import scala.util.control.Breaks + +import org.apache.flume.{Transaction, Channel} + +// Flume forces transactions to be thread-local (horrible, I know!) +// So the sink basically spawns a new thread to pull the events out within a transaction. +// The thread fills in the event batch object that is set before the thread is scheduled. +// After filling it in, the thread waits on a condition - which is released only +// when the success message comes back for the specific sequence number for that event batch. +/** + * This class represents a transaction on the Flume channel. This class runs a separate thread + * which owns the transaction. The thread is blocked until the success call for that transaction + * comes back with an ACK or NACK. + * @param channel The channel from which to pull events + * @param seqNum The sequence number to use for the transaction. Must be unique + * @param maxBatchSize The maximum number of events to process per batch + * @param transactionTimeout Time in seconds after which a transaction must be rolled back + * without waiting for an ACK from Spark + * @param parent The parent [[SparkAvroCallbackHandler]] instance, for reporting timeouts + */ +private class TransactionProcessor(val channel: Channel, val seqNum: String, + var maxBatchSize: Int, val transactionTimeout: Int, val backOffInterval: Int, + val parent: SparkAvroCallbackHandler) extends Callable[Void] with Logging { + + // If a real batch is not returned, we always have to return an error batch. + @volatile private var eventBatch: EventBatch = new EventBatch("Unknown Error", "", + util.Collections.emptyList()) + + // Synchronization primitives + val batchGeneratedLatch = new CountDownLatch(1) + val batchAckLatch = new CountDownLatch(1) + + // Sanity check to ensure we don't loop like crazy + val totalAttemptsToRemoveFromChannel = Int.MaxValue / 2 + + // OK to use volatile, since the change would only make this true (otherwise it will be + // changed to false - we never apply a negation operation to this) - which means the transaction + // succeeded. + @volatile private var batchSuccess = false + + // The transaction that this processor would handle + var txOpt: Option[Transaction] = None + + /** + * Get an event batch from the channel. This method will block until a batch of events is + * available from the channel. If no events are available after a large number of attempts of + * polling the channel, this method will return an [[EventBatch]] with a non-empty error message + * + * @return An [[EventBatch]] instance with sequence number set to seqNum, filled with a + * maximum of maxBatchSize events + */ + def getEventBatch: EventBatch = { + batchGeneratedLatch.await() + eventBatch + } + + /** + * This method is to be called by the sink when it receives an ACK or NACK from Spark. This + * method is a no-op if it is called after transactionTimeout has expired since + * getEventBatch returned a batch of events. + * @param success True if an ACK was received and the transaction should be committed, else false. + */ + def batchProcessed(success: Boolean) { + logDebug("Batch processed for sequence number: " + seqNum) + batchSuccess = success + batchAckLatch.countDown() + } + + /** + * Populates events into the event batch. 
If the batch cannot be populated, + * this method will not set the events into the event batch, but it sets an error message. + */ + private def populateEvents() { + try { + txOpt = Option(channel.getTransaction) + if(txOpt.isEmpty) { + eventBatch.setErrorMsg("Something went wrong. Channel was " + + "unable to create a transaction!") + } + txOpt.foreach(tx => { + tx.begin() + val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) + val loop = new Breaks + var gotEventsInThisTxn = false + var loopCounter: Int = 0 + loop.breakable { + while (events.size() < maxBatchSize + && loopCounter < totalAttemptsToRemoveFromChannel) { + loopCounter += 1 + Option(channel.take()) match { + case Some(event) => + events.add(new SparkSinkEvent(toCharSequenceMap(event.getHeaders), + ByteBuffer.wrap(event.getBody))) + gotEventsInThisTxn = true + case None => + if (!gotEventsInThisTxn) { + logDebug("Sleeping for " + backOffInterval + " millis as no events were read in" + + " the current transaction") + TimeUnit.MILLISECONDS.sleep(backOffInterval) + } else { + loop.break() + } + } + } + } + if (!gotEventsInThisTxn) { + val msg = "Tried several times, " + + "but did not get any events from the channel!" + logWarning(msg) + eventBatch.setErrorMsg(msg) + } else { + // At this point, the events are available, so fill them into the event batch + eventBatch = new EventBatch("",seqNum, events) + } + }) + } catch { + case e: Exception => + logWarning("Error while processing transaction.", e) + eventBatch.setErrorMsg(e.getMessage) + try { + txOpt.foreach(tx => { + rollbackAndClose(tx, close = true) + }) + } finally { + txOpt = None + } + } finally { + batchGeneratedLatch.countDown() + } + } + + /** + * Waits for upto transactionTimeout seconds for an ACK. If an ACK comes in + * this method commits the transaction with the channel. If the ACK does not come in within + * that time or a NACK comes in, this method rolls back the transaction. + */ + private def processAckOrNack() { + batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) + txOpt.foreach(tx => { + if (batchSuccess) { + try { + logDebug("Committing transaction") + tx.commit() + } catch { + case e: Exception => + logWarning("Error while attempting to commit transaction. Transaction will be rolled " + + "back", e) + rollbackAndClose(tx, close = false) // tx will be closed later anyway + } finally { + tx.close() + } + } else { + logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.") + rollbackAndClose(tx, close = true) + // This might have been due to timeout or a NACK. Either way the following call does not + // cause issues. This is required to ensure the TransactionProcessor instance is not leaked + parent.removeAndGetProcessor(seqNum) + } + }) + } + + /** + * Helper method to rollback and optionally close a transaction + * @param tx The transaction to rollback + * @param close Whether the transaction should be closed or not after rolling back + */ + private def rollbackAndClose(tx: Transaction, close: Boolean) { + try { + logWarning("Spark was unable to successfully process the events. Transaction is being " + + "rolled back.") + tx.rollback() + } catch { + case e: Exception => + logError("Error rolling back transaction. 
Rollback may have failed!", e) + } finally { + if (close) { + tx.close() + } + } + } + + /** + * Helper method to convert a Map[String, String] to Map[CharSequence, CharSequence] + * @param inMap The map to be converted + * @return The converted map + */ + private def toCharSequenceMap(inMap: java.util.Map[String, String]): java.util.Map[CharSequence, + CharSequence] = { + val charSeqMap = new util.HashMap[CharSequence, CharSequence](inMap.size()) + charSeqMap.putAll(inMap) + charSeqMap + } + + /** + * When the thread is started it sets as many events as the batch size or less (if enough + * events aren't available) into the eventBatch and object and lets any threads waiting on the + * [[getEventBatch]] method to proceed. Then this thread waits for acks or nacks to come in, + * or for a specified timeout and commits or rolls back the transaction. + * @return + */ + override def call(): Void = { + populateEvents() + processAckOrNack() + null + } +} diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 61a6aff543aed..c532705f3950c 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-flume_2.10 - streaming-flume + streaming-flume jar Spark Project External Flume @@ -72,11 +72,21 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface test + + org.apache.spark + spark-streaming-flume-sink_2.10 + ${project.version} + target/scala-${scala.binary.version}/classes diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala new file mode 100644 index 0000000000000..dc629df4f4ac2 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.io.{ObjectOutput, ObjectInput} + +import scala.collection.JavaConversions._ + +import org.apache.spark.util.Utils +import org.apache.spark.Logging + +/** + * A simple object that provides the implementation of readExternal and writeExternal for both + * the wrapper classes for Flume-style Events. 
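Since the wire layout of EventTransformer (defined just below) is only implied by its implementation, here is a minimal round-trip sketch. It is not part of the patch, the header and body values are made up, and because EventTransformer is private[streaming] it would only compile inside that package.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    val headers = new java.util.HashMap[CharSequence, CharSequence]()
    headers.put("host", "example.org")
    val body = "hello flume".getBytes("UTF-8")

    val bytesOut = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytesOut)
    EventTransformer.writeExternal(out, headers, body)  // body length + body, then each serialized header key/value
    out.flush()

    val in = new ObjectInputStream(new ByteArrayInputStream(bytesOut.toByteArray))
    val (readHeaders, readBody) = EventTransformer.readExternal(in)
    assert(new String(readBody, "UTF-8") == "hello flume")
    assert(readHeaders.get("host") == "example.org")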
+ */ +private[streaming] object EventTransformer extends Logging { + def readExternal(in: ObjectInput): (java.util.HashMap[CharSequence, CharSequence], + Array[Byte]) = { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.readFully(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.readFully(keyBuff) + val key: String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.readFully(valBuff) + val value: String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + (headers, bodyBuff) + } + + def writeExternal(out: ObjectOutput, headers: java.util.Map[CharSequence, CharSequence], + body: Array[Byte]) { + out.writeInt(body.length) + out.write(body) + val numHeaders = headers.size() + out.writeInt(numHeaders) + for ((k,v) <- headers) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 56d2886b26878..4b2ea45fb81d0 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -39,11 +39,8 @@ import org.apache.spark.streaming.receiver.Receiver import org.jboss.netty.channel.ChannelPipelineFactory import org.jboss.netty.channel.Channels -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.ChannelFactory import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.jboss.netty.handler.execution.ExecutionHandler private[streaming] class FlumeInputDStream[T: ClassTag]( diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala new file mode 100644 index 0000000000000..148262bb6771e --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.flume + + +import java.net.InetSocketAddress +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit, Executors} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.flume.sink._ + +/** + * A [[ReceiverInputDStream]] that can be used to read data from several Flume agents running + * [[org.apache.spark.streaming.flume.sink.SparkSink]]s. + * @param _ssc Streaming context that will execute this input stream + * @param addresses List of addresses at which SparkSinks are listening + * @param maxBatchSize Maximum size of a batch + * @param parallelism Number of parallel connections to open + * @param storageLevel The storage level to use. + * @tparam T Class type of the object of this stream + */ +private[streaming] class FlumePollingInputDStream[T: ClassTag]( + @transient _ssc: StreamingContext, + val addresses: Seq[InetSocketAddress], + val maxBatchSize: Int, + val parallelism: Int, + storageLevel: StorageLevel + ) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { + + override def getReceiver(): Receiver[SparkFlumeEvent] = { + new FlumePollingReceiver(addresses, maxBatchSize, parallelism, storageLevel) + } +} + +private[streaming] class FlumePollingReceiver( + addresses: Seq[InetSocketAddress], + maxBatchSize: Int, + parallelism: Int, + storageLevel: StorageLevel + ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { + + lazy val channelFactoryExecutor = + Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). + setNameFormat("Flume Receiver Channel Thread - %d").build()) + + lazy val channelFactory = + new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) + + lazy val receiverExecutor = Executors.newFixedThreadPool(parallelism, + new ThreadFactoryBuilder().setDaemon(true).setNameFormat("Flume Receiver Thread - %d").build()) + + private lazy val connections = new LinkedBlockingQueue[FlumeConnection]() + + override def onStart(): Unit = { + // Create the connections to each Flume agent. + addresses.foreach(host => { + val transceiver = new NettyTransceiver(host, channelFactory) + val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) + connections.add(new FlumeConnection(transceiver, client)) + }) + for (i <- 0 until parallelism) { + logInfo("Starting Flume Polling Receiver worker threads starting..") + // Threads that pull data from Flume. 
+ receiverExecutor.submit(new Runnable { + override def run(): Unit = { + while (true) { + val connection = connections.poll() + val client = connection.client + try { + val eventBatch = client.getEventBatch(maxBatchSize) + if (!SparkSinkUtils.isErrorBatch(eventBatch)) { + // No error, proceed with processing data + val seq = eventBatch.getSequenceNumber + val events: java.util.List[SparkSinkEvent] = eventBatch.getEvents + logDebug( + "Received batch of " + events.size() + " events with sequence number: " + seq) + try { + // Convert each Flume event to a serializable SparkFlumeEvent + val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) + var j = 0 + while (j < events.size()) { + buffer += toSparkFlumeEvent(events(j)) + j += 1 + } + store(buffer) + logDebug("Sending ack for sequence number: " + seq) + // Send an ack to Flume so that Flume discards the events from its channels. + client.ack(seq) + logDebug("Ack sent for sequence number: " + seq) + } catch { + case e: Exception => + try { + // Let Flume know that the events need to be pushed back into the channel. + logDebug("Sending nack for sequence number: " + seq) + client.nack(seq) // If the agent is down, even this could fail and throw + logDebug("Nack sent for sequence number: " + seq) + } catch { + case e: Exception => logError( + "Sending Nack also failed. A Flume agent is down.") + } + TimeUnit.SECONDS.sleep(2L) // for now just leave this as a fixed 2 seconds. + logWarning("Error while attempting to store events", e) + } + } else { + logWarning("Did not receive events from Flume agent due to error on the Flume " + + "agent: " + eventBatch.getErrorMsg) + } + } catch { + case e: Exception => + logWarning("Error while reading data from Flume", e) + } finally { + connections.add(connection) + } + } + } + }) + } + } + + override def onStop(): Unit = { + logInfo("Shutting down Flume Polling Receiver") + receiverExecutor.shutdownNow() + connections.foreach(connection => { + connection.transceiver.close() + }) + channelFactory.releaseExternalResources() + } + + /** + * Utility method to convert [[SparkSinkEvent]] to [[SparkFlumeEvent]] + * @param event - Event to convert to SparkFlumeEvent + * @return - The SparkFlumeEvent generated from SparkSinkEvent + */ + private def toSparkFlumeEvent(event: SparkSinkEvent): SparkFlumeEvent = { + val sparkFlumeEvent = new SparkFlumeEvent() + sparkFlumeEvent.event.setBody(event.getBody) + sparkFlumeEvent.event.setHeaders(event.getHeaders) + sparkFlumeEvent + } +} + +/** + * A wrapper around the transceiver and the Avro IPC API. + * @param transceiver The transceiver to use for communication with Flume + * @param client The client that the callbacks are received on. 
+ */ +private class FlumeConnection(val transceiver: NettyTransceiver, + val client: SparkFlumeProtocol.Callback) + + + diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 716db9fa76031..4b732c1592ab2 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -17,12 +17,19 @@ package org.apache.spark.streaming.flume +import java.net.InetSocketAddress + +import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaInputDStream, JavaStreamingContext, JavaDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream + object FlumeUtils { + private val DEFAULT_POLLING_PARALLELISM = 5 + private val DEFAULT_POLLING_BATCH_SIZE = 1000 + /** * Create a input stream from a Flume source. * @param ssc StreamingContext object @@ -56,7 +63,7 @@ object FlumeUtils { ): ReceiverInputDStream[SparkFlumeEvent] = { val inputStream = new FlumeInputDStream[SparkFlumeEvent]( ssc, hostname, port, storageLevel, enableDecompression) - + inputStream } @@ -105,4 +112,135 @@ object FlumeUtils { ): JavaReceiverInputDStream[SparkFlumeEvent] = { createStream(jssc.ssc, hostname, port, storageLevel, enableDecompression) } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Address of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): ReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(ssc, Seq(new InetSocketAddress(hostname, port)), storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param addresses List of InetSocketAddresses representing the hosts to connect to. + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + addresses: Seq[InetSocketAddress], + storageLevel: StorageLevel + ): ReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(ssc, addresses, storageLevel, + DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * @param addresses List of InetSocketAddresses representing the hosts to connect to. 
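As a usage illustration for the new polling API (not part of the patch; the application name, host, port, and batch interval are assumptions), a Scala streaming job could be wired up as follows, using the createPollingStream overload shown above.

    import java.net.InetSocketAddress

    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.flume.FlumeUtils

    val sparkConf = new SparkConf().setAppName("FlumePollingExample")
    val ssc = new StreamingContext(sparkConf, Seconds(2))

    // Poll a SparkSink assumed to be listening on localhost:9999.
    val stream = FlumeUtils.createPollingStream(
      ssc, Seq(new InetSocketAddress("localhost", 9999)), StorageLevel.MEMORY_AND_DISK_SER_2)

    stream.map(event => new String(event.event.getBody.array(), "UTF-8")).print()

    ssc.start()
    ssc.awaitTermination()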
+ * @param maxBatchSize Maximum number of events to be pulled from the Spark sink in a + * single RPC call + * @param parallelism Number of concurrent requests this stream should send to the sink. Note + * that having a higher number of requests concurrently being pulled will + * result in this stream using more threads + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + addresses: Seq[InetSocketAddress], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): ReceiverInputDStream[SparkFlumeEvent] = { + new FlumePollingInputDStream[SparkFlumeEvent](ssc, addresses, maxBatchSize, + parallelism, storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Hostname of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, hostname, port, StorageLevel.MEMORY_AND_DISK_SER_2) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Hostname of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, Array(new InetSocketAddress(hostname, port)), storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param addresses List of InetSocketAddresses on which the Spark Sink is running. + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + addresses: Array[InetSocketAddress], + storageLevel: StorageLevel + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, addresses, storageLevel, + DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * @param addresses List of InetSocketAddresses on which the Spark Sink is running + * @param maxBatchSize The maximum number of events to be pulled from the Spark sink in a + * single RPC call + * @param parallelism Number of concurrent requests this stream should send to the sink. 
Note + * that having a higher number of requests concurrently being pulled will + * result in this stream using more threads + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + addresses: Array[InetSocketAddress], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + } } diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java new file mode 100644 index 0000000000000..79c5b91654b42 --- /dev/null +++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume; + +import java.net.InetSocketAddress; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.LocalJavaStreamingContext; + +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.junit.Test; + +public class JavaFlumePollingStreamSuite extends LocalJavaStreamingContext { + @Test + public void testFlumeStream() { + // tests the API, does not actually test data receiving + InetSocketAddress[] addresses = new InetSocketAddress[] { + new InetSocketAddress("localhost", 12345) + }; + JavaReceiverInputDStream test1 = + FlumeUtils.createPollingStream(ssc, "localhost", 12345); + JavaReceiverInputDStream test2 = FlumeUtils.createPollingStream( + ssc, "localhost", 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = FlumeUtils.createPollingStream( + ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test4 = FlumeUtils.createPollingStream( + ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2(), 100, 5); + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala new file mode 100644 index 0000000000000..47071d0cc4714 --- /dev/null +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.streaming.flume + +import java.net.InetSocketAddress +import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} + +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables +import org.apache.flume.event.EventBuilder + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} +import org.apache.spark.streaming.flume.sink._ + +class FlumePollingStreamSuite extends TestSuiteBase { + + val testPort = 9999 + val batchCount = 5 + val eventsPerBatch = 100 + val totalEventsPerChannel = batchCount * eventsPerBatch + val channelCapacity = 5000 + + test("flume polling test") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", testPort)), + StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + ssc.start() + + writeAndVerify(Seq(channel), ssc, outputBuffer) + assertChannelIsEmpty(channel) + sink.stop() + channel.stop() + } + + test("flume polling test multiple hosts") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, + eventsPerBatch, 5) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + + // Start the channel and sink. 
+ val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort + 1)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + ssc.start() + writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + sink.stop() + channel.stop() + } + + def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, + outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) { + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + channels.map(channel => { + executorCompletion.submit(new TxnSubmitter(channel, clock)) + }) + for (i <- 0 until channels.size) { + executorCompletion.take() + } + val startTime = System.currentTimeMillis() + while (outputBuffer.size < batchCount * channels.size && + System.currentTimeMillis() - startTime < 15000) { + logInfo("output.size = " + outputBuffer.size) + Thread.sleep(100) + } + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping context") + ssc.stop() + + val flattenedBuffer = outputBuffer.flatten + assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + + String.valueOf(i)).getBytes("utf-8"), + Map[String, String]("test-" + i.toString -> "header")) + var found = false + var j = 0 + while (j < flattenedBuffer.size && !found) { + val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") + if (new String(eventToVerify.getBody, "utf-8") == strToCompare && + eventToVerify.getHeaders.get("test-" + i.toString) + .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { + found = true + counter += 1 + } + j += 1 + } + } + assert(counter === totalEventsPerChannel * channels.size) + } + + def assertChannelIsEmpty(channel: MemoryChannel) = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining"); + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) + } + + private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( + "utf-8"), + Map[String, String]("test-" + t.toString -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow 
some time for the events to reach + clock.addToTime(batchDuration.milliseconds) + } + null + } + } +} diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4762c50685a93..daf03360bc5f5 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-kafka_2.10 - streaming-kafka + streaming-kafka jar Spark Project External Kafka @@ -80,6 +80,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 38095e88dcea9..e20e2c8f26991 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import scala.collection.Map -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import java.util.Properties import java.util.concurrent.Executors @@ -48,8 +48,8 @@ private[streaming] class KafkaInputDStream[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -66,8 +66,8 @@ private[streaming] class KafkaReceiver[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel @@ -103,10 +103,10 @@ class KafkaReceiver[ tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) } - val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[K]] - val valueDecoder = manifest[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 86bb91f362d29..48668f763e41e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -65,7 +65,7 @@ object KafkaUtils { * in its own thread. 
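The Manifest-to-ClassTag switch in these Kafka classes does not change how the Scala API is called; decoder type parameters are simply resolved through ClassTag context bounds now. A hedged sketch (ssc, the ZooKeeper address, group id, and topic name are assumptions):

    import kafka.serializer.StringDecoder

    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.kafka.KafkaUtils

    val kafkaParams = Map("zookeeper.connect" -> "localhost:2181", "group.id" -> "example-group")
    val lines = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Map("events" -> 1), StorageLevel.MEMORY_AND_DISK_SER_2)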
* @param storageLevel Storage level to use for storing the received objects */ - def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: Manifest, T <: Decoder[_]: Manifest]( + def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -89,8 +89,6 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) } @@ -111,8 +109,6 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -140,13 +136,11 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[K, V] = { - implicit val keyCmt: ClassTag[K] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] - implicit val valueCmt: ClassTag[V] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + implicit val keyCmt: ClassTag[K] = ClassTag(keyTypeClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueTypeClass) - implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]] - implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]] + implicit val keyCmd: ClassTag[U] = ClassTag(keyDecoderClass) + implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 32c530e600ce0..dc48a08c93de2 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-mqtt_2.10 - streaming-mqtt + streaming-mqtt jar Spark Project External MQTT @@ -67,6 +67,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 637adb0f00da0..b93ad016f84f0 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-twitter_2.10 - streaming-twitter + streaming-twitter jar Spark Project External Twitter @@ -62,6 +62,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e4d758a04a4cd..22c1fff23d9a2 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-zeromq_2.10 - streaming-zeromq + streaming-zeromq jar Spark Project External ZeroMQ @@ -62,6 +62,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 3eade411b38b7..5308bb4e440ea 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -50,6 +50,11 @@ ${project.version} test-jar + + junit + junit + test + com.novocode junit-interface diff --git a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala 
b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index d03d7774e8c80..3b1880e143513 100644 --- a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -82,5 +82,9 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry, override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/graphx/pom.xml b/graphx/pom.xml index 7e3bcf29dcfbc..6dd52fc618b1e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-graphx_2.10 - graphx + graphx jar Spark Project GraphX diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index 5318b8da6412a..714f3b81c9dad 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -28,7 +28,7 @@ import org.apache.spark.rdd.{ShuffledRDD, RDD} private[graphx] class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner) + val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner) // Set a custom serializer if the data is of int or double type. if (classTag[VD] == ClassTag.Int) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index a565d3b28bf52..b27485953f719 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -33,7 +33,7 @@ private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage]( + new ShuffledRDD[VertexId, Int, Int]( self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 635514f09ece0..60149548ab852 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -100,8 +100,10 @@ object GraphGenerators { */ private def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int): Int = { val rand = new Random() - val m = math.exp(mu + (sigma * sigma) / 2.0) - val s = math.sqrt((math.exp(sigma*sigma) - 1) * math.exp(2*mu + sigma*sigma)) + val sigmaSq = sigma * sigma + val m = math.exp(mu + sigmaSq / 2.0) + // expm1 is exp(m)-1 with better accuracy for tiny m + val s = math.sqrt(math.expm1(sigmaSq) * math.exp(2*mu + sigmaSq)) // Z ~ N(0, 1) var X: Double = maxVal diff --git a/make-distribution.sh b/make-distribution.sh index c08093f46b61f..1441497b3995a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -128,7 +128,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then if [[ ! 
$REPLY =~ ^[Yy]$ ]]; then echo "Okay, exiting." exit 1 - fi + fi fi if [ "$NAME" == "none" ]; then @@ -150,7 +150,7 @@ else fi # Build uber fat JAR -cd $FWDIR +cd "$FWDIR" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" @@ -173,7 +173,7 @@ cp $FWDIR/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" +cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" if [ "$SPARK_HIVE" == "true" ]; then cp $FWDIR/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" @@ -199,7 +199,7 @@ cp -r "$FWDIR/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.4.1" + TACHYON_VERSION="0.5.0" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` diff --git a/mllib/pom.xml b/mllib/pom.xml index 92b07e2357db1..45046eca5b18c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-mllib_2.10 - mllib + mllib jar Spark Project ML Library @@ -60,6 +60,10 @@ junit junit + + org.apache.commons + commons-math3 + @@ -72,6 +76,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 82a993d03119b..7d912737b8f0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -20,15 +20,20 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -226,7 +231,7 @@ class PythonMLLibAPI extends Serializable { jsc: JavaSparkContext, path: String, minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint).toJavaRDD() + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, @@ -466,4 +471,130 @@ class PythonMLLibAPI extends Serializable { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + + /** + * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String). + * Returns the correlation matrix serialized into a byte array understood by deserializers in + * pyspark. 
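The corr stubs introduced here delegate to the Scala Statistics API; purely for orientation, a minimal Scala-side sketch (sc and the sample vectors are assumptions, not from the patch):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.Statistics

    val observations = sc.parallelize(Seq(
      Vectors.dense(1.0, 10.0),
      Vectors.dense(2.0, 21.0),
      Vectors.dense(3.0, 29.0)))
    val pearsonMatrix  = Statistics.corr(observations)              // Pearson correlation matrix (the default)
    val spearmanMatrix = Statistics.corr(observations, "spearman")  // Spearman's rank correlation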
+ */ + def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { + val inputMatrix = X.rdd.map(deserializeDoubleVector(_)) + val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) + serializeDoubleMatrix(to2dArray(result)) + } + + /** + * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). + */ + def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { + val xDeser = x.rdd.map(deserializeDouble(_)) + val yDeser = y.rdd.map(deserializeDouble(_)) + Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) + } + + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark + private def getCorrNameOrDefault(method: String) = { + if (method == null) CorrelationNames.defaultCorrName else method + } + + // Reformat a Matrix into Array[Array[Double]] for serialization + private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { + val values = matrix.toArray + Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + } + + // Used by the *RDD methods to get default seed if not passed in from pyspark + private def getSeedOrDefault(seed: java.lang.Long): Long = { + if (seed == null) Utils.random.nextLong else seed + } + + // Used by *RDD methods to get default numPartitions if not passed in from pyspark + private def getNumPartitionsOrDefault(numPartitions: java.lang.Integer, + jsc: JavaSparkContext): Int = { + if (numPartitions == null) { + jsc.sc.defaultParallelism + } else { + numPartitions + } + } + + // Note: for the following methods, numPartitions and seed are boxed to allow nulls to be passed + // in for either argument from pyspark + + /** + * Java stub for Python mllib RandomRDDGenerators.uniformRDD() + */ + def uniformRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalRDD() + */ + def normalRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonRDD() + */ + def poissonRDD(jsc: JavaSparkContext, + mean: Double, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.uniformVectorRDD() + */ + def uniformVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalVectorRDD() + */ + def normalVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + 
val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonVectorRDD() + */ + def poissonVectorRDD(jsc: JavaSparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala new file mode 100644 index 0000000000000..0f6d5809e098f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + * + * @param numFeatures number of features (default: 1000000) + */ +@Experimental +class HashingTF(val numFeatures: Int) extends Serializable { + + def this() = this(1000000) + + /** + * Returns the index of the input term. + */ + def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) + + /** + * Transforms the input document into a sparse term frequency vector. + */ + def transform(document: Iterable[_]): Vector = { + val termFrequencies = mutable.HashMap.empty[Int, Double] + document.foreach { term => + val i = indexOf(term) + termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0) + } + Vectors.sparse(numFeatures, termFrequencies.toSeq) + } + + /** + * Transforms the input document into a sparse term frequency vector (Java version). + */ + def transform(document: JavaIterable[_]): Vector = { + transform(document.asScala) + } + + /** + * Transforms the input document to term frequency vectors. + */ + def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { + dataset.map(this.transform) + } + + /** + * Transforms the input document to term frequency vectors (Java version). 
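A quick usage sketch of the new HashingTF (illustrative; sc and the sample documents are assumptions):

    import org.apache.spark.mllib.feature.HashingTF

    val hashingTF = new HashingTF()                       // defaults to 1,000,000 features
    val documents = sc.parallelize(Seq(
      "spark streaming flume".split(" ").toSeq,
      "spark mllib".split(" ").toSeq))
    val termFrequencies = hashingTF.transform(documents)  // RDD[Vector] of sparse term-frequency vectors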
+ */ + def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { + dataset.rdd.map(this.transform).toJavaRDD() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala new file mode 100644 index 0000000000000..7ed611a857acc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Inverse document frequency (IDF). + * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total + * number of documents and `d(t)` is the number of documents that contain term `t`. + */ +@Experimental +class IDF { + + // TODO: Allow different IDF formulations. + + private var brzIdf: BDV[Double] = _ + + /** + * Computes the inverse document frequency. + * @param dataset an RDD of term frequency vectors + */ + def fit(dataset: RDD[Vector]): this.type = { + brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + seqOp = (df, v) => df.add(v), + combOp = (df1, df2) => df1.merge(df2) + ).idf() + this + } + + /** + * Computes the inverse document frequency. + * @param dataset a JavaRDD of term frequency vectors + */ + def fit(dataset: JavaRDD[Vector]): this.type = { + fit(dataset.rdd) + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors. + * @param dataset an RDD of term frequency vectors + * @return an RDD of TF-IDF vectors + */ + def transform(dataset: RDD[Vector]): RDD[Vector] = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. 
Call fit first.") + } + val theIdf = brzIdf + val bcIdf = dataset.context.broadcast(theIdf) + dataset.mapPartitions { iter => + val thisIdf = bcIdf.value + iter.map { v => + val n = v.size + v match { + case sv: SparseVector => + val nnz = sv.indices.size + val newValues = new Array[Double](nnz) + var k = 0 + while (k < nnz) { + newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) + k += 1 + } + Vectors.sparse(n, sv.indices, newValues) + case dv: DenseVector => + val newValues = new Array[Double](n) + var j = 0 + while (j < n) { + newValues(j) = dv.values(j) * thisIdf(j) + j += 1 + } + Vectors.dense(newValues) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } + } + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). + * @param dataset a JavaRDD of term frequency vectors + * @return a JavaRDD of TF-IDF vectors + */ + def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(dataset.rdd).toJavaRDD() + } + + /** Returns the IDF vector. */ + def idf(): Vector = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") + } + Vectors.fromBreeze(brzIdf) + } + + private def initialized: Boolean = brzIdf != null +} + +private object IDF { + + /** Document frequency aggregator. */ + class DocumentFrequencyAggregator extends Serializable { + + /** number of documents */ + private var m = 0L + /** document frequency vector */ + private var df: BDV[Long] = _ + + /** Adds a new document. */ + def add(doc: Vector): this.type = { + if (isEmpty) { + df = BDV.zeros(doc.size) + } + doc match { + case sv: SparseVector => + val nnz = sv.indices.size + var k = 0 + while (k < nnz) { + if (sv.values(k) > 0) { + df(sv.indices(k)) += 1L + } + k += 1 + } + case dv: DenseVector => + val n = dv.size + var j = 0 + while (j < n) { + if (dv.values(j) > 0.0) { + df(j) += 1L + } + j += 1 + } + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + m += 1L + this + } + + /** Merges another. */ + def merge(other: DocumentFrequencyAggregator): this.type = { + if (!other.isEmpty) { + m += other.m + if (df == null) { + df = other.df.copy + } else { + df += other.df + } + } + this + } + + private def isEmpty: Boolean = m == 0L + + /** Returns the current IDF vector. 
*/ + def idf(): BDV[Double] = { + if (isEmpty) { + throw new IllegalStateException("Haven't seen any document yet.") + } + val n = df.length + val inv = BDV.zeros[Double](n) + var j = 0 + while (j < n) { + inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) + j += 1 + } + inv + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8c2b044ea73f2..58c1322757a43 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} /** @@ -79,7 +80,7 @@ class RowMatrix( private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = { val n = numCols().toInt val vbr = rows.context.broadcast(v) - rows.aggregate(BDV.zeros[Double](n))( + rows.treeAggregate(BDV.zeros[Double](n))( seqOp = (U, r) => { val rBrz = r.toBreeze val a = rBrz.dot(vbr.value) @@ -91,9 +92,7 @@ class RowMatrix( s"Do not support vector operation from type ${rBrz.getClass.getName}.") } U - }, - combOp = (U1, U2) => U1 += U2 - ) + }, combOp = (U1, U2) => U1 += U2) } /** @@ -104,13 +103,11 @@ class RowMatrix( val nt: Int = n * (n + 1) / 2 // Compute the upper triangular part of the gram matrix. - val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))( + val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { RowMatrix.dspr(1.0, v, U.data) U - }, - combOp = (U1, U2) => U1 += U2 - ) + }, combOp = (U1, U2) => U1 += U2) RowMatrix.triuToFull(n, GU.data) } @@ -290,9 +287,10 @@ class RowMatrix( s"We need at least $mem bytes of memory.") } - val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( + val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), - combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2) + combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => + (s1._1 + s2._1, s1._2 += s2._2) ) updateNumRows(m) @@ -353,10 +351,9 @@ class RowMatrix( * Computes column-wise summary statistics. 
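 * A minimal usage sketch (editorial illustration, not part of the original patch;
 * `rows` is a hypothetical, pre-built RDD[Vector]):
 * {{{
 *   val mat = new RowMatrix(rows)
 *   val summary = mat.computeColumnSummaryStatistics()
 *   val columnMeans = summary.mean         // per-column means
 *   val columnVariances = summary.variance // per-column variances
 * }}}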
*/ def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { - val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)( + val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), - (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - ) + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) updateNumRows(summary.count) summary } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 679842f831c2a..9d82f011e674a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -68,9 +68,9 @@ class LogisticGradient extends Gradient { val gradient = brzData * gradientMultiplier val loss = if (label > 0) { - math.log(1 + math.exp(margin)) + math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p } else { - math.log(1 + math.exp(margin)) - margin + math.log1p(math.exp(margin)) - margin } (Vectors.fromBreeze(gradient), loss) @@ -89,9 +89,9 @@ class LogisticGradient extends Gradient { brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) if (label > 0) { - math.log(1 + math.exp(margin)) + math.log1p(math.exp(margin)) } else { - math.log(1 + math.exp(margin)) - margin + math.log1p(math.exp(margin)) - margin } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 9fd760bf78083..356aa949afcf5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -25,6 +25,7 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Class used to solve an optimization problem using Gradient Descent. 
@@ -177,7 +178,7 @@ object GradientDescent extends Logging { // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .aggregate((BDV.zeros[Double](n), 0.0))( + .treeAggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) (grad, loss + l) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 179cd4a3f1625..26a2b62e76ed0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -26,6 +26,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.rdd.RDDFunctions._ /** * :: DeveloperApi :: @@ -199,7 +200,7 @@ object LBFGS extends Logging { val n = weights.length val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))( + val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala index d7ee2d3f46846..021d651d4dbaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -26,14 +26,17 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. */ @Experimental object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -49,7 +52,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -63,9 +69,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. 
* sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n).map(v => a + (b - a) * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. @@ -77,7 +86,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -93,7 +105,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -107,9 +122,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n).map(v => mean + sigma * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). @@ -121,7 +139,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -142,7 +160,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -157,7 +175,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -172,7 +190,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. 
* * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -192,7 +210,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -210,7 +228,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -229,7 +247,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. @@ -251,14 +269,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. */ @Experimental def uniformVectorRDD(sc: SparkContext, @@ -270,14 +288,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. */ @Experimental def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -286,7 +304,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. @@ -294,7 +312,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). 
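 * For example (illustrative values only, not part of the original patch),
 * `normalVectorRDD(sc, 1000L, 10, 4, seed = 11L)` produces 1000 vectors of 10
 * standard-normal entries each, spread over 4 partitions.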
*/ @Experimental def normalVectorRDD(sc: SparkContext, @@ -308,14 +326,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, @@ -327,14 +345,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -343,7 +361,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -352,7 +370,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -367,7 +385,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -375,7 +393,7 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -388,7 +406,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -396,7 +414,7 @@ object RandomRDDGenerators { * @param mean Mean, or lambda, for the Poisson distribution. * @param numRows Number of Vectors in the RDD. 
* @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -408,7 +426,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -417,7 +435,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, @@ -431,7 +449,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -439,7 +457,7 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, @@ -452,7 +470,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -460,7 +478,7 @@ object RandomRDDGenerators { * @param generator DistributionGenerator used to populate the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 365b5e75d7f75..b5e403bc8c14d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.HashPartitioner +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. @@ -44,6 +47,69 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { new SlidingRDD[T](self, windowSize) } } + + /** + * Reduces the elements of this RDD in a multi-level tree pattern. 
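 * A minimal usage sketch (editorial illustration, not part of the original patch;
 * assumes `import org.apache.spark.mllib.rdd.RDDFunctions._` for the implicit conversion):
 * {{{
 *   val sum = sc.parallelize(1 to 1000, 100).treeReduce((a, b) => a + b, depth = 3)
 * }}}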
+ * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val cleanF = self.context.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#aggregate]] + */ + def treeAggregate[U: ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + if (self.partitions.size == 0) { + return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance()) + } + val cleanSeqOp = self.context.clean(seqOp) + val cleanCombOp = self.context.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.size + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. 
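    // Worked example of the loop below (editorial note, not part of the original patch):
    // with 1000 partitions and the default depth = 2, scale = ceil(1000^(1/2)) = 32, so one
    // pass shuffles the partial aggregates into 1000 / 32 = 31 partitions; since
    // 31 <= 32 + 31 / 32, the loop then stops and the final reduce runs over those 31 values.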
+ while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => + iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) + } } private[mllib] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 5356790cb5339..36d262fed425a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -255,6 +255,9 @@ class ALS private ( rank, lambda, alpha, YtY) previousProducts.unpersist() logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + products.checkpoint() + } products.setName(s"products-$iter").persist() val XtX = Some(sc.broadcast(computeYtY(products))) val previousUsers = users @@ -268,6 +271,9 @@ class ALS private ( logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, rank, lambda, alpha, YtY = None) + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + products.checkpoint() + } products.setName(s"products-$iter") logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, @@ -284,8 +290,8 @@ class ALS private ( val usersOut = unblockFactors(users, userOutLinks) val productsOut = unblockFactors(products, productOutLinks) - usersOut.setName("usersOut").persist() - productsOut.setName("productsOut").persist() + usersOut.setName("usersOut").persist(StorageLevel.MEMORY_AND_DISK) + productsOut.setName("productsOut").persist(StorageLevel.MEMORY_AND_DISK) // Materialize usersOut and productsOut. usersOut.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 899286d235a9d..a1a76fcbe9f9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -65,6 +65,48 @@ class MatrixFactorizationModel private[mllib] ( } } + /** + * Recommends products to a user. + * + * @param user the user to recommend products to + * @param num how many products to return. The number returned may be less than this. + * @return [[Rating]] objects, each of which contains the given user ID, a product ID, and a + * "score" in the rating field. Each represents one recommended product, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the user. The score is an opaque value that indicates how strongly + * recommended the product is. + */ + def recommendProducts(user: Int, num: Int): Array[Rating] = + recommend(userFeatures.lookup(user).head, productFeatures, num) + .map(t => Rating(user, t._1, t._2)) + + /** + * Recommends users to a product. That is, this returns users who are most likely to be + * interested in a product. + * + * @param product the product to recommend users to + * @param num how many users to return. 
The number returned may be less than this. + * @return [[Rating]] objects, each of which contains a user ID, the given product ID, and a + * "score" in the rating field. Each represents one recommended user, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the product. The score is an opaque value that indicates how strongly + * recommended the user is. + */ + def recommendUsers(product: Int, num: Int): Array[Rating] = + recommend(productFeatures.lookup(product).head, userFeatures, num) + .map(t => Rating(t._1, product, t._2)) + + private def recommend( + recommendToFeatures: Array[Double], + recommendableFeatures: RDD[(Int, Array[Double])], + num: Int): Array[(Int, Double)] = { + val recommendToVector = new DoubleMatrix(recommendToFeatures) + val scored = recommendableFeatures.map { case (id,features) => + (id, recommendToVector.dot(new DoubleMatrix(features))) + } + scored.top(num)(Ordering.by(_._2)) + } + /** * :: DeveloperApi :: * Predict the rating of many users for many products. @@ -80,6 +122,4 @@ class MatrixFactorizationModel private[mllib] ( predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) } - // TODO: Figure out what other good bulk prediction methods would look like. - // Probably want a way to get the top users for a product or vice-versa. } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 68f3867ba6c11..f416a9fbb323d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -23,23 +23,26 @@ import org.apache.spark.mllib.stat.correlation.Correlations import org.apache.spark.rdd.RDD /** - * API for statistical functions in MLlib + * API for statistical functions in MLlib. */ @Experimental object Statistics { /** + * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. - * Returns NaN if either vector has 0 variance. + * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. */ + @Experimental def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** + * :: Experimental :: * Compute the correlation matrix for the input RDD of Vectors using the specified method. - * Methods currently supported: `pearson` (default), `spearman` + * Methods currently supported: `pearson` (default), `spearman`. * * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], @@ -51,28 +54,39 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. */ + @Experimental def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** + * :: Experimental :: * Compute the Pearson correlation for the input RDDs. - * Columns with 0 covariance produce NaN entries in the correlation matrix. + * Returns NaN if either vector has 0 variance. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. 
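 * A minimal usage sketch (editorial illustration, not part of the original patch;
 * `xs` and `ys` are hypothetical names):
 * {{{
 *   val xs = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0), 2)
 *   val ys = sc.parallelize(Seq(2.0, 4.0, 6.0, 8.0), 2)
 *   val r = Statistics.corr(xs, ys)  // Pearson correlation; 1.0 for this perfectly linear data
 * }}}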
* - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ + @Experimental def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** + * :: Experimental :: * Compute the correlation for the input RDDs using the specified method. - * Methods currently supported: pearson (default), spearman + * Methods currently supported: `pearson` (default), `spearman`. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` *@return A Double containing the correlation between the two input RDD[Double]s using the * specified method. */ + @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala index f23393d3da257..1fb8d7b3d4f32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala @@ -49,43 +49,48 @@ private[stat] trait Correlation { } /** - * Delegates computation to the specific correlation object based on the input method name - * - * Currently supported correlations: pearson, spearman. - * After new correlation algorithms are added, please update the documentation here and in - * Statistics.scala for the correlation APIs. - * - * Maintains the default correlation type, pearson + * Delegates computation to the specific correlation object based on the input method name. */ private[stat] object Correlations { - // Note: after new types of correlations are implemented, please update this map - val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) - val defaultCorrName: String = "pearson" - val defaultCorr: Correlation = nameToObjectMap(defaultCorrName) - - def corr(x: RDD[Double], y: RDD[Double], method: String = defaultCorrName): Double = { + def corr(x: RDD[Double], + y: RDD[Double], + method: String = CorrelationNames.defaultCorrName): Double = { val correlation = getCorrelationFromName(method) correlation.computeCorrelation(x, y) } - def corrMatrix(X: RDD[Vector], method: String = defaultCorrName): Matrix = { + def corrMatrix(X: RDD[Vector], + method: String = CorrelationNames.defaultCorrName): Matrix = { val correlation = getCorrelationFromName(method) correlation.computeCorrelationMatrix(X) } - /** - * Match input correlation name with a known name via simple string matching - * - * private to stat for ease of unit testing - */ - private[stat] def getCorrelationFromName(method: String): Correlation = { + // Match input correlation name with a known name via simple string matching. 
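  // For example (editorial note, not part of the original patch):
  // getCorrelationFromName("spearman") returns SpearmanCorrelation, while an unrecognized
  // name such as "kendall" falls through to the IllegalArgumentException below, whose
  // message lists the supported correlation names.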
+ def getCorrelationFromName(method: String): Correlation = { try { - nameToObjectMap(method) + CorrelationNames.nameToObjectMap(method) } catch { case nse: NoSuchElementException => throw new IllegalArgumentException("Unrecognized method name. Supported correlations: " - + nameToObjectMap.keys.mkString(", ")) + + CorrelationNames.nameToObjectMap.keys.mkString(", ")) } } } + +/** + * Maintains supported and default correlation names. + * + * Currently supported correlations: `pearson`, `spearman`. + * Current default correlation: `pearson`. + * + * After new correlation algorithms are added, please update the documentation here and in + * Statistics.scala for the correlation APIs. + */ +private[mllib] object CorrelationNames { + + // Note: after new types of correlations are implemented, please update this map. + val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) + val defaultCorrName: String = "pearson" + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 1f7de630e778c..9bd0c2cd05de4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -89,20 +89,18 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => // add an extra element to signify the end of the list so that flatMap can flush the last // batch of duplicates - val padded = iter ++ - Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L)) - var lastVal = 0.0 - var firstRank = 0.0 - val idBuffer = new ArrayBuffer[Long]() + val end = -1L + val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) + val firstEntry = padded.next() + var lastVal = firstEntry._1._1 + var firstRank = firstEntry._2.toDouble + val idBuffer = ArrayBuffer(firstEntry._1._2) padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != Long.MinValue) { + if (v == lastVal && id != end) { idBuffer += id Iterator.empty } else { - val entries = if (idBuffer.size == 0) { - // edge case for the first value matching the initial value of lastVal - Iterator.empty - } else if (idBuffer.size == 1) { + val entries = if (idBuffer.size == 1) { Iterator((idBuffer(0), firstRank)) } else { val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad32e3f4560fe..382e76a9b7cba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * :: Experimental :: - * A class that implements a decision tree algorithm for classification and regression. It - * supports both continuous and categorical features. + * A class which implements a decision tree learning algorithm for classification and regression. + * It supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. 
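 * A minimal usage sketch (editorial illustration, not part of the original patch;
 * `strategy`, `trainingData`, and `testPoint` are hypothetical, pre-built values):
 * {{{
 *   val model = new DecisionTree(strategy).train(trainingData)
 *   val prediction = model.predict(testPoint.features)
 * }}}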
@@ -42,8 +42,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Method to train a decision tree model over an RDD - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @return a DecisionTreeModel that can be used for prediction + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -100,7 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var level = 0 var break = false - while (level < maxDepth && !break) { + while (level <= maxDepth && !break) { logDebug("#####################################") logDebug("level = " + level) @@ -152,7 +152,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -173,7 +173,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { + if (level < maxDepth) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { @@ -197,17 +197,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo object DecisionTree extends Serializable with Logging { /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. The parameters for the algorithm are specified using the strategy parameter. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. 
- * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).train(input) @@ -219,12 +218,14 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @return a DecisionTreeModel that can be used for prediction + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -241,13 +242,15 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -266,11 +269,13 @@ object DecisionTree extends Serializable with Logging { * 1 to denote the two classes. The method also supports categorical features inputs where the * number of categories can specified using the categoricalFeaturesInfo option. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. 
* @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles @@ -279,7 +284,7 @@ object DecisionTree extends Serializable with Logging { * an entry (n -> k) implies the feature n is categorical with k * categories 0, 1, 2, ... , k-1. It's important to note that * features are zero-indexed. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -301,11 +306,10 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -348,11 +352,10 @@ object DecisionTree extends Serializable with Logging { /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -373,7 +376,7 @@ object DecisionTree extends Serializable with Logging { groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* - * The high-level description for the best split optimizations are noted here. + * The high-level descriptions of the best split optimizations are noted here. * * *Level-wise training* * We perform bin calculations for all nodes at the given level to avoid making multiple @@ -396,18 +399,27 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // common calculations for multiple nested methods + // Common calculations for multiple nested methods: + + // numNodes: Number of nodes in this (level of tree, group), + // where nodes at deeper (larger) levels may be divided into groups. val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. 
val numFeatures = input.first().features.size logDebug("numFeatures = " + numFeatures) + + // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) + val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlassClassificationWithCategoricalFeatures = strategy.isMulticlassWithCategoricalFeatures logDebug("isMultiClassWithCategoricalFeatures = " + @@ -465,10 +477,13 @@ object DecisionTree extends Serializable with Logging { } /** - * Find bin for one feature. + * Find bin for one (labeledPoint, feature). */ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -483,7 +498,7 @@ object DecisionTree extends Serializable with Logging { val bin = binForFeatures(mid) val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)){ + if ((lowThreshold < feature) && (highThreshold >= feature)) { return mid } else if (lowThreshold >= feature) { @@ -507,41 +522,51 @@ object DecisionTree extends Serializable with Logging { } /** - * Sequential search helper method to find bin for categorical feature. + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). */ - def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureValue = labeledPoint.features(featureIndex) var binIndex = 0 - while (binIndex < numCategoricalBins) { + while (binIndex < featureCategories) { val bin = bins(featureIndex)(binIndex) val categories = bin.highSplit.categories - val features = labeledPoint.features - if (categories.contains(features(featureIndex))) { + if (categories.contains(featureValue)) { return binIndex } binIndex += 1 } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } -1 } if (isFeatureContinuous) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for continuous variable.") } binIndex } else { // Perform sequential search to find bin for categorical features. 
val binIndex = { - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { - sequentialBinSearchForOrderedCategoricalFeatureInClassification() + sequentialBinSearchForOrderedCategoricalFeature() } } - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for categorical variable.") } binIndex @@ -555,6 +580,14 @@ object DecisionTree extends Serializable with Logging { * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. + * + * For unordered features, the "bin index" returned is actually the feature value (category). + * + * @return Array of size 1 + numFeatures * numNodes, where + * arr(0) = label for labeledPoint, and + * arr(1 + numFeatures * nodeIndex + featureIndex) = + * bin index for this labeledPoint + * (or InvalidBinIndex if labeledPoint is not handled by this node) */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. @@ -598,58 +631,85 @@ object DecisionTree extends Serializable with Logging { // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int) = { - + /** + * Increment aggregate in location for (node, feature, bin, label). + * + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. + */ + def updateBinForOrderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + label: Double, + featureIndex: Int): Unit = { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + val aggIndex = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + numClasses * arr(arrIndex).toInt + + label.toInt + agg(aggIndex) += 1 } - def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], - label: Double, agg: Array[Double], rightChildShift: Int) = { + /** + * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), + * where [bins] ranges over all bins. + * Updates left or right side of aggregate depending on split. + * + * @param arr arr(0) = label. + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param agg Indexed by (left/right, node, feature, bin, label) + * where label is the least significant bit. + * The left/right specifier is a 0/1 index indicating left/right child info. + * @param rightChildShift Offset for right side of agg. 
+ */ + def updateBinForUnorderedFeature( + nodeIndex: Int, + featureIndex: Int, + arr: Array[Double], + label: Double, + agg: Array[Double], + rightChildShift: Int): Unit = { // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + val arrIndex = 1 + numFeatures * nodeIndex + featureIndex + val featureValue = arr(arrIndex).toInt // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val aggShift = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + label.toInt // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { - val labelInt = label.toInt - if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + 1 + val aggIndex = aggShift + binIndex * numClasses + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(aggIndex) += 1 } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + 1 + agg(rightChildShift + aggIndex) += 1 } binIndex += 1 } } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. */ - def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -671,17 +731,21 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numClasses * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). 
+ * For ordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. + * For unordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). + * @param agg Array storing aggregate calculation. + * For ordered features, this is of size: + * numClasses * numBins * numFeatures * numNodes. + * For unordered features, this is of size: + * 2 * numClasses * numBins * numFeatures * numNodes. */ - def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -717,16 +781,17 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * Performs a sequential aggregation over a partition for regression. + * For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. * - * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression + * @param agg Array storing aggregate calculation, updated by this function. + * Size: 3 * numBins * numFeatures * numNodes + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -757,14 +822,30 @@ object DecisionTree extends Serializable with Logging { /** * Performs a sequential aggregation over a partition. + * For l nodes, k features, + * For classification: + * Either the left count or the right count of one of the bins is + * incremented based upon whether the feature is classified as 0 or 1. + * For regression: + * The count, sum, sum of squares of one of the bins is incremented. + * + * @param agg Array storing aggregate calculation, updated by this function. + * Size for classification: + * numClasses * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. + * Size for regression: + * 3 * numBins * numFeatures * numNodes. + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). 
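For the regression aggregate described in the doc comment above (three values per bin: count, sum, sum of squares), a minimal sketch with invented labels shows why those three numbers are enough to recover mean and variance later without another pass over the data.

    // Sketch: regression keeps (count, sum, sumSquares) per bin.
    object RegressionAggSketch {
      def main(args: Array[String]): Unit = {
        val labelsInBin = Seq(1.0, 2.5, 4.0)   // hypothetical labels that fell into one bin
        val agg = Array(0.0, 0.0, 0.0)         // count, sum, sumSquares
        labelsInBin.foreach { y =>
          agg(0) += 1
          agg(1) += y
          agg(2) += y * y
        }
        val Array(count, sum, sumSq) = agg
        val mean = sum / count
        val variance = sumSq / count - mean * mean
        println(s"count=$count mean=$mean variance=$variance")
      }
    }
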
+ * @return agg */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => if(isMulticlassClassificationWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg) + multiclassWithCategoricalBinSeqOp(arr, agg) } else { - orderedClassificationBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(arr, agg) } case Regression => regressionBinSeqOp(arr, agg) } @@ -815,20 +896,10 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - var classIndex = 0 - val leftCounts: Array[Double] = new Array[Double](numClasses) - val rightCounts: Array[Double] = new Array[Double](numClasses) - var leftTotalCount = 0.0 - var rightTotalCount = 0.0 - while (classIndex < numClasses) { - val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) - val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) - leftCounts(classIndex) = leftClassCount - leftTotalCount += leftClassCount - rightCounts(classIndex) = rightClassCount - rightTotalCount += rightClassCount - classIndex += 1 - } + val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) + val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) + val leftTotalCount = leftCounts.sum + val rightTotalCount = rightCounts.sum val impurity = { if (level > 0) { @@ -845,33 +916,17 @@ object DecisionTree extends Serializable with Logging { } } - if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) - } - if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) - } - - val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) - val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - - val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. 
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } // Sum of count for each label - val leftRightCounts: Array[Double] - = leftCounts.zip(rightCounts) - .map{case (leftCount, rightCount) => leftCount + rightCount} + val leftRightCounts: Array[Double] = + leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => + leftCount + rightCount + } def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { @@ -885,6 +940,22 @@ object DecisionTree extends Serializable with Logging { val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(leftCounts, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(rightCounts, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) @@ -937,10 +1008,18 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], - * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, - * (numBins - 1), numClasses) + * @param binData Aggregate array slice from getBinDataForNode. + * For classification: + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. + * For ordered features, this is of size numClasses * numBins * numFeatures. + * For regression: + * This is of size 2 * numFeatures * numBins. + * @return (leftNodeAgg, rightNodeAgg) pair of arrays. + * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). + * For regression, each array is of size (numFeatures, (numBins - 1), 3). + * */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { @@ -983,6 +1062,11 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value + */ def findAggForUnorderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1052,7 +1136,7 @@ object DecisionTree extends Serializable with Logging { val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures){ + if (isMulticlassClassificationWithCategoricalFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -1107,7 +1191,7 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. 
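The weighted information-gain arithmetic in the rewritten block above (gain = parent impurity minus the count-weighted child impurities) can be checked with a tiny worked example; Gini impurity and the class counts below are invented for illustration.

    // Sketch: gain = parentImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity.
    object GainSketch {
      def gini(counts: Array[Double]): Double = {
        val total = counts.sum
        if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
      }

      def main(args: Array[String]): Unit = {
        val leftCounts = Array(8.0, 2.0)    // hypothetical class counts sent left by a split
        val rightCounts = Array(1.0, 9.0)   // and sent right
        val parentCounts = leftCounts.zip(rightCounts).map { case (l, r) => l + r }
        val (leftTotal, rightTotal) = (leftCounts.sum, rightCounts.sum)
        val total = leftTotal + rightTotal
        val gain = gini(parentCounts) -
          (leftTotal / total) * gini(leftCounts) -
          (rightTotal / total) * gini(rightCounts)
        println(f"gain = $gain%.4f")  // positive: the split separates the classes well
      }
    }
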
- * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param binData Bin data slice for this node, given by getBinDataForNode. * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -1133,7 +1217,7 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex : Double = { + val maxSplitIndex: Double = { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { numBins - 1 @@ -1162,8 +1246,8 @@ object DecisionTree extends Serializable with Logging { (bestFeatureIndex, bestSplitIndex, bestGainStats) } + logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } @@ -1214,8 +1298,17 @@ object DecisionTree extends Serializable with Logging { bestSplits } - private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + /** + * Get the number of values to be stored per node in the bin aggregates. + * + * @param numBins Number of bins = 1 + number of possible splits. + */ + private def getElementsPerNode( + numFeatures: Int, + numBins: Int, + numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, + algo: Algo): Int = { algo match { case Classification => if (isMulticlassClassificationWithCategoricalFeatures) { @@ -1228,18 +1321,40 @@ object DecisionTree extends Serializable with Logging { } /** - * Returns split and bins for decision tree calculation. - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * There are 2^(maxFeatureValue - 1) - 1 splits. + * (b) For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one split per category. + + * Categorical case (a) features are called unordered features. + * Other cases are called ordered features. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree - * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache - * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + * parameters for construction the DecisionTree + * @return A tuple of (splits,bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numBins - 1). 
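As a sanity check on the per-node aggregate sizes that the getElementsPerNode doc above refers to, a small sketch with illustrative (not test-derived) dimensions:

    // Sketch of the per-node aggregate sizes described in the doc comments above.
    object ElementsPerNodeSketch {
      def main(args: Array[String]): Unit = {
        val (numFeatures, numBins, numClasses) = (5, 32, 3)  // hypothetical problem
        val orderedClassification = numClasses * numBins * numFeatures
        val unorderedClassification = 2 * numClasses * numBins * numFeatures  // left/right halves
        val regression = 3 * numBins * numFeatures  // count, sum, sumSquares per bin
        println(s"ordered=$orderedClassification unordered=$unorderedClassification regression=$regression")
      }
    }
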
+ * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() // Find the number of features by looking at the first sample @@ -1271,7 +1386,8 @@ object DecisionTree extends Serializable with Logging { logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val sampledInput = + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length val stride: Double = numSamples.toDouble / numBins @@ -1286,7 +1402,7 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 - while (featureIndex < numFeatures){ + while (featureIndex < numFeatures) { // Check whether the feature is continuous. val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { @@ -1294,8 +1410,10 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) for (index <- 0 until numBins - 1) { - val sampleIndex = (index + 1) * stride.toInt - val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + val sampleIndex = index * stride.toInt + // Set threshold halfway in between 2 samples. + val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + val split = new Split(featureIndex, threshold, Continuous, List()) splits(featureIndex)(index) = split } } else { // Categorical feature @@ -1304,8 +1422,10 @@ object DecisionTree extends Serializable with Logging { = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + // classification that satisfy the space constraint. + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -1330,8 +1450,13 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { - + } else { // ordered feature + /* For a given categorical feature, use a subsample of the data + * to choose how to arrange possible splits. + * This examines each category and computes a centroid. + * These centroids are later used to sort the possible splits. + * centroidForCategories is a mapping: category (for the given feature) --> centroid + */ val centroidForCategories = { if (isMulticlassClassification) { // For categorical variables in multiclass classification, @@ -1341,7 +1466,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. 
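The centroid-then-sort idea sketched in the comment above can be illustrated for the binary/regression case, where (as an assumption for this sketch, not a quote of the patch) the centroid is simply the mean label per category; the sample data below is made up.

    // Sketch: order categories of an ordered categorical feature by the mean label
    // ("centroid") of the samples in each category, then only consider splits along
    // that ordering.
    object CategoryOrderingSketch {
      def main(args: Array[String]): Unit = {
        // (featureValue, label) pairs from a hypothetical subsample.
        val samples = Seq((0.0, 1.0), (0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (2.0, 0.0))
        val centroidForCategories: Map[Double, Double] =
          samples.groupBy(_._1).map { case (cat, xs) => cat -> xs.map(_._2).sum / xs.size }
        val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2).map(_._1)
        println(categoriesSortedByCentroid)  // List(1.0, 2.0, 0.0)
        // Candidate splits are then prefixes of this ordering: {1.0}, {1.0, 2.0}, ...
      }
    }
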
The bins are sorted and they @@ -1352,7 +1477,7 @@ object DecisionTree extends Serializable with Logging { } } - logDebug("centriod for categories = " + centroidForCategories.mkString(",")) + logDebug("centroid for categories = " + centroidForCategories.mkString(",")) // Check for missing categorical variables and putting them last in the sorted list. val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() @@ -1367,7 +1492,7 @@ object DecisionTree extends Serializable with Logging { // bins sorted by centroids val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) var categoriesForSplit = List[Double]() categoriesSortedByCentroid.iterator.zipWithIndex.foreach { @@ -1397,7 +1522,7 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) - for (index <- 1 until numBins - 1){ + for (index <- 1 until numBins - 1) { val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Continuous, Double.MinValue) bins(featureIndex)(index) = bin diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7c027ac2fda6b..5c65b537b6867 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,7 +27,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * Stores all the configuration options for tree construction * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. 
Default value is 2 * leads to binary classification * @param maxBins maximum number of bins used for splitting features @@ -52,7 +53,9 @@ class Strategy ( val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), val maxMemoryInMB: Int = 128) extends Serializable { - require(numClassesForClassification >= 2) + if (algo == Classification) { + require(numClassesForClassification >= 2) + } val isMulticlassClassification = numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a0e2d91762782..9297c20596527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -34,10 +34,13 @@ object Entropy extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 0.0 var classIndex = 0 @@ -58,6 +61,7 @@ object Entropy extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 48144b5e6d1e4..2874bcf496484 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -33,10 +33,13 @@ object Gini extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 1.0 var classIndex = 0 @@ -54,6 +57,7 @@ object Gini extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 7b2a9320cc21d..92b0c7b4a6fbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -31,7 +31,7 @@ trait Impurity extends Serializable { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if 
totalCount = 0 */ @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -42,7 +42,7 @@ trait Impurity extends Serializable { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 97149a99ead59..698a1a2a8e899 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -31,7 +31,7 @@ object Variance extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = @@ -43,9 +43,13 @@ object Variance extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + if (count == 0) { + return 0 + } val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index bf692ca8c4bd7..3d3406b5d5f22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,8 @@ import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: - * Model to store the decision tree parameters + * Decision tree model for classification or regression. + * This model stores the decision tree structure and parameters. * @param topNode root node * @param algo algorithm type -- classification or regression */ @@ -50,4 +51,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } + + /** + * Get number of nodes in tree, including leaf nodes. + */ + def numNodes: Int = { + 1 + topNode.numDescendants + } + + /** + * Get depth of tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + def depth: Int = { + topNode.subtreeDepth + } + + /** + * Print full model. 
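A hedged usage sketch for the model helpers added above (numNodes, depth, and the printable tree). The model value is assumed to come from training elsewhere; only the new accessors are exercised here.

    import org.apache.spark.mllib.tree.model.DecisionTreeModel

    object ModelInspection {
      // `model` is assumed to be a trained DecisionTreeModel; this helper only
      // reads the accessors introduced in the patch above.
      def describe(model: DecisionTreeModel): String =
        s"nodes = ${model.numNodes}, depth = ${model.depth}\n" + model.toString
    }
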
+ */ + override def toString: String = algo match { + case Classification => + s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2) + case Regression => + s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2) + case _ => throw new IllegalArgumentException( + s"DecisionTreeModel given unknown algo parameter: $algo.") + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 682f213f411a7..944f11c2c2e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -91,4 +91,60 @@ class Node ( } } } + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int = { + if (isLeaf) { + 0 + } else { + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. + */ + private[tree] def subtreeDepth: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) + } + } + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String = { + + def splitToString(split: Split, left: Boolean): String = { + split.featureType match { + case Continuous => if (left) { + s"(feature ${split.feature} <= ${split.threshold})" + } else { + s"(feature ${split.feature} > ${split.threshold})" + } + case Categorical => if (left) { + s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + } else { + s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + } + } + } + val prefix: String = " " * indentFactor + if (isLeaf) { + prefix + s"Predict: $predict\n" + } else { + prefix + s"If ${splitToString(split.get, left=true)}\n" + + leftNode.get.subtreeToString(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + + rightNode.get.subtreeToString(indentFactor + 1) + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala deleted file mode 100644 index e25bf18b780bf..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.util - -/** Trait for label parsers. 
*/ -private trait LabelParser extends Serializable { - /** Parses a string label into a double label. */ - def parse(labelString: String): Double -} - -/** Factory methods for label parsers. */ -private object LabelParser { - def getInstance(multiclass: Boolean): LabelParser = { - if (multiclass) MulticlassLabelParser else BinaryLabelParser - } -} - -/** - * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5, - * or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling. - */ -private object BinaryLabelParser extends LabelParser { - /** Gets the default instance of BinaryLabelParser. */ - def getInstance(): LabelParser = this - - /** - * Parses the input label into positive (1.0) if the value is greater than 0.5, - * or negative (0.0) otherwise. - */ - override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 -} - -/** - * Label parser for multiclass labels, which converts the input label to double. - */ -private object MulticlassLabelParser extends LabelParser { - /** Gets the default instance of MulticlassLabelParser. */ - def getInstance(): LabelParser = this - - override def parse(labelString: String): Double = labelString.toDouble -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 30de24ad89f98..dc10a194783ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -55,7 +55,6 @@ object MLUtils { * * @param sc Spark context * @param path file or directory path in any Hadoop-supported file system URI - * @param labelParser parser for labels * @param numFeatures number of features, which will be determined from the input data if a * nonpositive value is given. This is useful when the dataset is already split * into multiple files and you want to load them separately, because some @@ -64,10 +63,9 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] */ - private def loadLibSVMFile( + def loadLibSVMFile( sc: SparkContext, path: String, - labelParser: LabelParser, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = { val parsed = sc.textFile(path, minPartitions) @@ -75,7 +73,7 @@ object MLUtils { .filter(line => !(line.isEmpty || line.startsWith("#"))) .map { line => val items = line.split(' ') - val label = labelParser.parse(items.head) + val label = items.head.toDouble val (indices, values) = items.tail.map { item => val indexAndValue = item.split(':') val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. @@ -102,64 +100,46 @@ object MLUtils { // Convenient methods for `loadLibSVMFile`. - /** - * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint]. - * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR. - * Each line represents a labeled sparse feature vector using the following format: - * {{{label index1:value1 index2:value2 ...}}} - * where the indices are one-based and in ascending order. - * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]], - * where the feature indices are converted to zero-based. - * - * @param sc Spark context - * @param path file or directory path in any Hadoop-supported file system URI - * @param multiclass whether the input labels contain more than two classes. 
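The LIBSVM line format described in the scaladoc above, `label index1:value1 index2:value2 ...` with one-based indices converted to zero-based, can be illustrated with a small standalone parser sketch (not the MLUtils code itself):

    // Sketch: parse one LIBSVM-format line into (label, zero-based indices, values).
    object LibSVMLineSketch {
      def parse(line: String): (Double, Array[Int], Array[Double]) = {
        val items = line.split(' ').filter(_.nonEmpty)
        val label = items.head.toDouble
        val (indices, values) = items.tail.map { item =>
          val Array(i, v) = item.split(':')
          (i.toInt - 1, v.toDouble)  // convert 1-based index to 0-based
        }.unzip
        (label, indices, values)
      }

      def main(args: Array[String]): Unit = {
        val (label, idx, vals) = parse("1 1:0.5 3:2.0")
        println(s"label=$label indices=${idx.mkString(",")} values=${vals.mkString(",")}")
      }
    }
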
If false, any label - * with value greater than 0.5 will be mapped to 1.0, or 0.0 otherwise. So it - * works for both +1/-1 and 1/0 cases. If true, the double value parsed directly - * from the label string will be used as the label value. - * @param numFeatures number of features, which will be determined from the input data if a - * nonpositive value is given. This is useful when the dataset is already split - * into multiple files and you want to load them separately, because some - * features may not present in certain files, which leads to inconsistent - * feature dimensions. - * @param minPartitions min number of partitions - * @return labeled data stored as an RDD[LabeledPoint] - */ - def loadLibSVMFile( + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") + def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, LabelParser.getInstance(multiclass), numFeatures, minPartitions) + loadLibSVMFile(sc, path, numFeatures, minPartitions) /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. */ + def loadLibSVMFile( + sc: SparkContext, + path: String, + numFeatures: Int): RDD[LabeledPoint] = + loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean, numFeatures: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass, numFeatures, sc.defaultMinPartitions) + loadLibSVMFile(sc, path, numFeatures) - /** - * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the number of features - * determined automatically and the default number of partitions. - */ + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass, -1, sc.defaultMinPartitions) + loadLibSVMFile(sc, path) /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. */ def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass = false, -1, sc.defaultMinPartitions) + loadLibSVMFile(sc, path, -1) /** * Save labeled data in LIBSVM format. diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java new file mode 100644 index 0000000000000..e8d99f4ae43ae --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; + +public class JavaTfIdfSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaTfIdfSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void tfIdf() { + // The tests are to check Java compatibility. + HashingTF tf = new HashingTF(); + JavaRDD> documents = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("this is a sentence".split(" ")), + Lists.newArrayList("this is another sentence".split(" ")), + Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD termFreqs = tf.transform(documents); + termFreqs.collect(); + IDF idf = new IDF(); + JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); + List localTfIdfs = tfIdfs.collect(); + int indexOfThis = tf.indexOf("this"); + for (Vector v: localTfIdfs) { + Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index bf2365f82044c..f6ca9643227f8 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -20,6 +20,11 @@ import java.io.Serializable; import java.util.List; +import scala.Tuple2; +import scala.Tuple3; + +import org.jblas.DoubleMatrix; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -28,8 +33,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.jblas.DoubleMatrix; - public class JavaALSSuite implements Serializable { private transient JavaSparkContext sc; @@ -44,21 +47,28 @@ public void tearDown() { sc = null; } - static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { + static void validatePrediction( + MatrixFactorizationModel model, + int users, + int products, + int features, + DoubleMatrix trueRatings, + double matchThreshold, + boolean implicitPrefs, + DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = model.userFeatures().toJavaRDD().collect(); + List> userFeatures = model.userFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 userFeature : userFeatures) { + for (Tuple2 userFeature : userFeatures) { predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); } } DoubleMatrix predictedP = new 
DoubleMatrix(products, features); - List> productFeatures = + List> productFeatures = model.productFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 productFeature : productFeatures) { + for (Tuple2 productFeature : productFeatures) { predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); } } @@ -75,7 +85,8 @@ static void validatePrediction(MatrixFactorizationModel model, int users, int pr } } } else { - // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests) + // For implicit prefs we use the confidence-weighted RMSE to test + // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; for (int u = 0; u < users; ++u) { @@ -100,7 +111,7 @@ public void runALSUsingStaticMethods() { int iterations = 15; int users = 50; int products = 100; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); @@ -114,14 +125,14 @@ public void runALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) - .setIterations(iterations) - .run(data.rdd()); + .setIterations(iterations) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); } @@ -131,7 +142,7 @@ public void runImplicitALSUsingStaticMethods() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -145,7 +156,7 @@ public void runImplicitALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -163,12 +174,42 @@ public void runImplicitALSWithNegativeWeight() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, true); JavaRDD data = sc.parallelize(testData._1()); - MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } + @Test + public void runRecommend() { + int features = 5; + int iterations = 10; + int users = 200; + int products = 50; + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true, 
false); + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); + validateRecommendations(model.recommendProducts(1, 10), 10); + validateRecommendations(model.recommendUsers(1, 20), 20); + } + + private static void validateRecommendations(Rating[] recommendations, int howMany) { + Assert.assertEquals(howMany, recommendations.length); + for (int i = 1; i < recommendations.length; i++) { + Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); + } + Assert.assertTrue(recommendations[0].rating() > 0.7); + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index d94cfa2fcec81..bd413a80f5107 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends FunSuite { @@ -59,10 +59,25 @@ class PythonMLLibAPISuite extends FunSuite { } test("double serialization") { - for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) { + for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { val bytes = py.serializeDouble(x) val deser = py.deserializeDouble(bytes) - assert(x === deser) + // We use `equals` here for comparison because we cannot use `==` for NaN + assert(x.equals(deser)) } } + + test("matrix to 2D array") { + val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) + val matrix = Matrices.dense(2, 3, values) + val arr = py.to2dArray(matrix) + val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) + assert(arr === expected) + + // Test conversion for empty matrix + val empty = Array[Double]() + val emptyMatrix = Matrices.dense(0, 0, empty) + val empty2D = py.to2dArray(emptyMatrix) + assert(empty2D === Array[Array[Double]]()) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 3f6ff859374c7..da7c633bbd2af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ object LogisticRegressionSuite { @@ -81,9 +82,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD) // Test the weights - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) val validationData = 
LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -113,9 +113,9 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD, initialWeights) - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.01) + assert(model.intercept ~== 1.97 relTol 0.01) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 34bc4537a7b3a..afa1f79b95a12 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,8 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class KMeansSuite extends FunSuite with LocalSparkContext { @@ -41,26 +42,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // centered at the mean of the points var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train( data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("no distinct points") { @@ -104,26 +105,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { var model = KMeans.train(data, k = 1, maxIterations = 1) assert(model.clusterCenters.size === 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + 
assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("single cluster with sparse data") { @@ -149,31 +150,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext { val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0))) var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) data.unpersist() } test("k-means|| initialization") { + + case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { + @Override def compare(that: VectorWithCompare): Int = { + if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + } + } + val points = Seq( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), @@ -188,15 +197,19 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // unselected point as long as it hasn't yet selected all of them var model = KMeans.train(rdd, k = 5, maxIterations = 1) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Iterations of Lloyd's should not change the answer either model = KMeans.train(rdd, k = 5, maxIterations = 10) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Neither should more runs model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) } test("two clusters") { diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 1c9844f289fe0..994e0feb8629e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 - assert(AreaUnderCurve.of(curve) === auc) + assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) == auc) + assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5) } test("auc of an empty curve") { val curve = Seq.empty[(Double, Double)] - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } test("auc of a curve with a single point") { val curve = Seq((1.0, 1.0)) - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 94db1dc183230..a733f88b60b80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,25 +20,14 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals +import org.apache.spark.mllib.util.TestingUtils._ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { - // TODO: move utility functions to TestingUtils. 
+ def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = { - actual.zip(expected).forall { case (x1, x2) => - x1.almostEquals(x2) - } - } - - def elementsAlmostEqual( - actual: Seq[(Double, Double)], - expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = { - actual.zip(expected).forall { case ((x1, y1), (x2, y2)) => - x1.almostEquals(x2) && y1.almostEquals(y2) - } - } + def cond2(x: ((Double, Double), (Double, Double))): Boolean = + (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( @@ -57,16 +46,17 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) val pr = recall.zip(precision) val prCurve = Seq((0.0, 1.0)) ++ pr - val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold)) - assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve)) - assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve))) - assert(elementsAlmostEqual(metrics.pr().collect(), prCurve)) - assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))) - assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision))) - assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall))) + + assert(metrics.thresholds().collect().zip(threshold).forall(cond1)) + assert(metrics.roc().collect().zip(rocCurve).forall(cond2)) + assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5) + assert(metrics.pr().collect().zip(prCurve).forall(cond2)) + assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5) + assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2)) + assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2)) + assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2)) + assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala new file mode 100644 index 0000000000000..a599e0d938569 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
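The f1 and f2 expressions in the metrics test above are instances of the general F-beta measure, F_beta = (1 + beta^2) * p * r / (beta^2 * p + r); a quick standalone check with invented precision/recall values:

    // Sketch: beta = 1 gives 2pr/(p + r); beta = 2 gives 5pr/(4p + r).
    object FMeasureSketch {
      def fMeasure(precision: Double, recall: Double, beta: Double): Double = {
        val b2 = beta * beta
        (1 + b2) * precision * recall / (b2 * precision + recall)
      }

      def main(args: Array[String]): Unit = {
        val (p, r) = (0.8, 0.5)
        println(fMeasure(p, r, beta = 1.0))  // == 2 * p * r / (p + r)
        println(fMeasure(p, r, beta = 2.0))  // == 5 * p * r / (4 * p + r)
      }
    }
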
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.LocalSparkContext + +class HashingTFSuite extends FunSuite with LocalSparkContext { + + test("hashing tf on a single doc") { + val hashingTF = new HashingTF(1000) + val doc = "a a b b c d".split(" ") + val n = hashingTF.numFeatures + val termFreqs = Seq( + (hashingTF.indexOf("a"), 2.0), + (hashingTF.indexOf("b"), 2.0), + (hashingTF.indexOf("c"), 1.0), + (hashingTF.indexOf("d"), 1.0)) + assert(termFreqs.map(_._1).forall(i => i >= 0 && i < n), + "index must be in range [0, #features)") + assert(termFreqs.map(_._1).toSet.size === 4, "expecting perfect hashing") + val expected = Vectors.sparse(n, termFreqs) + assert(hashingTF.transform(doc) === expected) + } + + test("hashing tf on an RDD") { + val hashingTF = new HashingTF + val localDocs: Seq[Seq[String]] = Seq( + "a a b b b c d".split(" "), + "a b c d a b c".split(" "), + "c b a c b a a".split(" ")) + val docs = sc.parallelize(localDocs, 2) + assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala new file mode 100644 index 0000000000000..78a2804ff204b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
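A minimal usage sketch of the HashingTF API exercised in the new suite above, using only the constructor, numFeatures, indexOf, and transform calls that appear there (object name illustrative; requires the mllib artifact on the classpath):

import org.apache.spark.mllib.feature.HashingTF

object HashingTFExample {
  def main(args: Array[String]): Unit = {
    val hashingTF = new HashingTF(1000)   // 1000 hash buckets, as in the test
    val doc = "a a b b c d".split(" ")
    println(hashingTF.numFeatures)        // 1000
    println(hashingTF.indexOf("a"))       // hash bucket assigned to the term "a"
    println(hashingTF.transform(doc))     // sparse vector of the term counts 2.0, 2.0, 1.0, 1.0
  }
}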
+ */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class IDFSuite extends FunSuite with LocalSparkContext { + + test("idf") { + val n = 4 + val localTermFrequencies = Seq( + Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(n, Array(1), Array(1.0)) + ) + val m = localTermFrequencies.size + val termFrequencies = sc.parallelize(localTermFrequencies, 2) + val idf = new IDF + intercept[IllegalStateException] { + idf.idf() + } + intercept[IllegalStateException] { + idf.transform(termFrequencies) + } + idf.fit(termFrequencies) + val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((m.toDouble + 1.0) / (x + 1.0)) + }) + assert(idf.idf() ~== expected absTol 1e-12) + val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(tfidf.size === 3) + val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] + assert(tfidf0.indices === Array(1, 3)) + assert(Vectors.dense(tfidf0.values) ~== + Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12) + val tfidf1 = tfidf(1L).asInstanceOf[DenseVector] + assert(Vectors.dense(tfidf1.values) ~== + Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12) + val tfidf2 = tfidf(2L).asInstanceOf[SparseVector] + assert(tfidf2.indices === Array(1)) + assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index dfb2eb7f0d14e..bf040110e228b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -126,19 +127,14 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept) - def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - - assert(compareDouble( - loss1(0), - loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + - math.pow(initialWeightsWithIntercept(1), 2)) / 2), + assert( + loss1(0) ~= (loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + + math.pow(initialWeightsWithIntercept(1), 2)) / 2) absTol 1E-5, """For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""") assert( - compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) && - compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)), + (newWeights1(0) ~= (newWeights0(0) - initialWeightsWithIntercept(0)) absTol 1E-5) && + (newWeights1(1) ~= (newWeights0(1) - initialWeightsWithIntercept(1)) absTol 1E-5), "The different between newWeights with/without regularization " + "should be 
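The expected vector in the idf test above encodes document frequencies df = (0, 3, 1, 2) over m = 3 documents, plugged into idf(t) = log((m + 1) / (df(t) + 1)). A standalone check of those numbers (object name illustrative):

object IdfFormulaCheck {
  def main(args: Array[String]): Unit = {
    val m = 3.0                          // number of documents in the test
    val df = Seq(0.0, 3.0, 1.0, 2.0)     // per-feature document frequencies
    val idf = df.map(d => math.log((m + 1.0) / (d + 1.0)))
    println(idf)   // List(1.3862..., 0.0, 0.6931..., 0.2876...)
  }
}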
initialWeightsWithIntercept.") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index ff414742e8393..5f4c24115ac80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { @@ -49,10 +50,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { lazy val dataRDD = sc.parallelize(data, 2).cache() - def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - test("LBFGS loss should be decreasing and match the result of Gradient Descent.") { val regParam = 0 @@ -126,15 +123,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { miniBatchFrac, initialWeightsWithIntercept) - assert(compareDouble(lossGD(0), lossLBFGS(0)), + assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5, "The first losses of LBFGS and GD should be the same.") // The 2% difference here is based on observation, but is not theoretically guaranteed. - assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02), + assert(lossGD.last ~= lossLBFGS.last relTol 0.02, "The last losses of LBFGS and GD should be within 2% difference.") - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } @@ -226,8 +223,8 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { initialWeightsWithIntercept) // for class LBFGS and the optimize method, we only look at the weights - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index bbf385229081a..b781a6aed9a8c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -21,7 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas} +import org.jblas.{DoubleMatrix, SimpleBlas} + +import org.apache.spark.mllib.util.TestingUtils._ class NNLSSuite extends FunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. 
*/ @@ -73,7 +75,7 @@ class NNLSSuite extends FunSuite { val ws = NNLS.createWorkspace(n) val x = NNLS.solve(ata, atb, ws) for (i <- 0 until n) { - assert(Math.abs(x(i) - goodx(i)) < 1e-3) + assert(x(i) ~== goodx(i) absTol 1E-3) assert(x(i) >= 0) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 3f3b10dfff35e..27a19f793242b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,4 +46,22 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val expected = data.flatMap(x => x).sliding(3).toList assert(sliding.collect().toList === expected) } + + test("treeAggregate") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def seqOp = (c: Long, x: Int) => c + x + def combOp = (c1: Long, c2: Long) => c1 + c2 + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) + assert(sum === -1000L) + } + } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (depth <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, depth) + assert(sum === -1000) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index bce4251426df7..a3f76f77a5dcc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -31,6 +31,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) val data = Seq( Vectors.dense(1.0, 0.0, 0.0, -2.0), Vectors.dense(4.0, 5.0, 0.0, 3.0), @@ -46,6 +47,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val p1 = Statistics.corr(x, y, "pearson") assert(approxEqual(expected, default)) assert(approxEqual(expected, p1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val p2 = Statistics.corr(x1, y1) + assert(approxEqual(expected, p2)) + } + + // RDD of zero variance + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z).isNaN()) } test("corr(x, y) spearman") { @@ -54,6 +67,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val expected = 0.5 val s1 = Statistics.corr(x, y, "spearman") assert(approxEqual(expected, s1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val s2 = Statistics.corr(x1, y1, "spearman") + assert(approxEqual(expected, s2)) + } + + // RDD of zero variance => zero variance in ranks + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z, "spearman").isNaN()) } test("corr(X) default, pearson") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 4b7b019d820b4..db13f142df517 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -89,15 
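Both tree tests above expect -1000 because the RDD holds -1000 until 1000, i.e. the integers -1000 through 999, whose terms cancel pairwise except for -1000 itself. A Spark-free analogue of the same seqOp/combOp aggregation (object name illustrative):

object TreeAggregateCheck {
  def main(args: Array[String]): Unit = {
    val data = -1000 until 1000
    val seqOp = (c: Long, x: Int) => c + x
    val combOp = (c1: Long, c2: Long) => c1 + c2
    // Fold each "partition" with seqOp, then combine the partial sums with combOp,
    // mirroring what treeAggregate does across partitions at any depth.
    val (left, right) = data.splitAt(data.length / 2)
    println(combOp(left.foldLeft(0L)(seqOp), right.foldLeft(0L)(seqOp)))   // -1000
  }
}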
+89,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.dense(-1.0, 0.0, 6.0)) .add(Vectors.dense(3.0, -3.0, 0.0)) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -107,15 +107,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0)))) .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0)))) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -129,17 +129,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.dense(1.7, -0.6, 0.0)) .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) - assert(summarizer.mean.almostEquals( - Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -157,17 +157,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer = 
summarizer1.merge(summarizer2) - assert(summarizer.mean.almostEquals( - Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -186,24 +186,24 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer) assert(summarizer3.count === 0) - assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") - assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5961a618c59d9..546a132559326 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import 
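A minimal usage sketch of the online summarizer exercised above, assuming the class sits in org.apache.spark.mllib.stat next to its suite; samples are folded in with add, and two partial summaries can be combined with merge, as the merge tests rely on (object name illustrative):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

object SummarizerSketch {
  def main(args: Array[String]): Unit = {
    val summarizer = new MultivariateOnlineSummarizer()
      .add(Vectors.dense(-1.0, 0.0, 6.0))
      .add(Vectors.dense(3.0, -3.0, 0.0))
    println(summarizer.mean)       // [1.0,-1.5,3.0], as asserted in the dense-vector test
    println(summarizer.variance)   // [8.0,4.5,18.0]
    println(summarizer.count)      // 2
  }
}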
org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -31,6 +30,30 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy) + } + + def validateRegressor( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + (prediction - expected.label) * (prediction - expected.label) + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE) + } + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -50,7 +73,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) @@ -130,7 +153,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) @@ -236,7 +259,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("extract categories from a number for multiclass classification") { val l = DecisionTree.extractMultiClassCategories(13, 10) assert(l.length === 3) - assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } test("split and bin calculations for unordered categorical variables with multiclass " + @@ -247,7 +270,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -341,7 +364,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) @@ -397,7 +420,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, numClassesForClassification = 2, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -413,7 +436,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) assert(stats.predict === 1) - assert(stats.prob == 0.6) + assert(stats.prob === 0.6) 
assert(stats.impurity > 0.2) } @@ -424,7 +447,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Regression, Variance, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) @@ -439,10 +462,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict == 0.6) + assert(stats.predict === 0.6) assert(stats.impurity > 0.2) } + test("regression stump with categorical variables of arity 2") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + + val model = DecisionTree.train(rdd, strategy) + validateRegressor(model, arr, 0.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + test("stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) @@ -460,7 +500,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -483,7 +522,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -507,7 +545,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -531,7 +568,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -587,9 +623,68 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + 
assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + } + + test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) + arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + + test("stump with 2 continuous variables for binary classification") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + assert(model.topNode.split.get.feature === 1) + } + + test("stump with categorical variables for multiclass classification, with just enough bins") { + val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -600,14 +695,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.categories.length === 1) assert(bestSplit.categories.contains(1)) assert(bestSplit.featureType === Categorical) + val gain = bestSplits(0)._2 + assert(gain.leftImpurity === 0) + assert(gain.rightImpurity === 0) } test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -625,9 +727,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = 
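A condensed sketch of the new off-by-1 stump test above: the same Strategy and DecisionTree.train call on the four single-feature points, returning the model's predictions (assumes an existing SparkContext; object and method names are illustrative):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeStumpSketch {
  def trainStump(sc: SparkContext): Seq[Double] = {
    val points = Seq(
      new LabeledPoint(0.0, Vectors.dense(0.0)),
      new LabeledPoint(1.0, Vectors.dense(1.0)),
      new LabeledPoint(1.0, Vectors.dense(2.0)),
      new LabeledPoint(1.0, Vectors.dense(3.0)))
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
      numClassesForClassification = 2)
    val model = DecisionTree.train(sc.parallelize(points), strategy)
    points.map(p => model.predict(p.features))   // expected to reproduce the labels exactly
  }
}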
DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -644,7 +750,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index c14870fb969a8..8ef2bb1bf6a78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -63,9 +63,9 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { test("loadLibSVMFile") { val lines = """ - |+1 1:1.0 3:2.0 5:3.0 - |-1 - |-1 2:4.0 4:5.0 6:6.0 + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 """.stripMargin val tempDir = Files.createTempDir() tempDir.deleteOnExit() @@ -73,7 +73,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString - val pointsWithNumFeatures = loadLibSVMFile(sc, path, multiclass = false, 6).collect() + val pointsWithNumFeatures = loadLibSVMFile(sc, path, 6).collect() val pointsWithoutNumFeatures = loadLibSVMFile(sc, path).collect() for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { @@ -86,11 +86,11 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) } - val multiclassPoints = loadLibSVMFile(sc, path, multiclass = true).collect() + val multiclassPoints = loadLibSVMFile(sc, path).collect() assert(multiclassPoints.length === 3) assert(multiclassPoints(0).label === 1.0) - assert(multiclassPoints(1).label === -1.0) - assert(multiclassPoints(2).label === -1.0) + assert(multiclassPoints(1).label === 0.0) + assert(multiclassPoints(2).label === 0.0) Utils.deleteRecursively(tempDir) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 64b1ba7527183..29cc42d8cbea7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -18,28 +18,155 @@ package org.apache.spark.mllib.util import org.apache.spark.mllib.linalg.Vector +import org.scalatest.exceptions.TestFailedException object TestingUtils { + val ABS_TOL_MSG = " using absolute 
tolerance" + val REL_TOL_MSG = " using relative tolerance" + + /** + * Private helper function for comparing two values using relative tolerance. + * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue, + * the relative tolerance is meaningless, so the exception will be raised to warn users. + */ + private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + val absX = math.abs(x) + val absY = math.abs(y) + val diff = math.abs(x - y) + if (x == y) { + true + } else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) { + throw new TestFailedException( + s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0) + } else { + diff < eps * math.min(absX, absY) + } + } + + /** + * Private helper function for comparing two values using absolute tolerance. + */ + private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + math.abs(x - y) < eps + } + + case class CompareDoubleRightSide( + fun: (Double, Double, Double) => Boolean, y: Double, eps: Double, method: String) + + /** + * Implicit class for comparing two double values using relative tolerance or absolute tolerance. + */ implicit class DoubleWithAlmostEquals(val x: Double) { - // An improved version of AlmostEquals would always divide by the larger number. - // This will avoid the problem of diving by zero. - def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = { - if(x == y) { - true - } else if(math.abs(x) > math.abs(y)) { - math.abs(x - y) / math.abs(x) < epsilon - } else { - math.abs(x - y) / math.abs(y) < epsilon + + /** + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareDoubleRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two values are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareDoubleRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two values are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareDoubleRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method}.", 0) } + true } + + /** + * Throws exception when the difference of two values are within eps; otherwise, returns true. + */ + def !~==(r: CompareDoubleRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method}.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, + x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. + */ + def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, + x, eps, REL_TOL_MSG) + + override def toString = x.toString } + case class CompareVectorRightSide( + fun: (Vector, Vector, Double) => Boolean, y: Vector, eps: Double, method: String) + + /** + * Implicit class for comparing two vectors using relative tolerance or absolute tolerance. + */ implicit class VectorWithAlmostEquals(val x: Vector) { - def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = { - x.toArray.corresponds(y.toArray) { - _.almostEquals(_, epsilon) + + /** + * When the difference of two vectors are within eps, returns true; otherwise, returns false. 
+ */ + def ~=(r: CompareVectorRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two vectors are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareVectorRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two vectors are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareVectorRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) } + true } + + /** + * Throws exception when the difference of two vectors are within eps; otherwise, returns true. + */ + def !~==(r: CompareVectorRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) + }, x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. Note that comparing against sparse vector + * with elements having value of zero will raise exception because it involves with + * comparing against zero. + */ + def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) + }, x, eps, REL_TOL_MSG) + + override def toString = x.toString } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala new file mode 100644 index 0000000000000..b0ecb33c28483 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.mllib.linalg.Vectors +import org.scalatest.FunSuite +import org.apache.spark.mllib.util.TestingUtils._ +import org.scalatest.exceptions.TestFailedException + +class TestingUtilsSuite extends FunSuite { + + test("Comparing doubles using relative error.") { + + assert(23.1 ~== 23.52 relTol 0.02) + assert(23.1 ~== 22.74 relTol 0.02) + assert(23.1 ~= 23.52 relTol 0.02) + assert(23.1 ~= 22.74 relTol 0.02) + assert(!(23.1 !~= 23.52 relTol 0.02)) + assert(!(23.1 !~= 22.74 relTol 0.02)) + + // Should throw exception with message when test fails. 
+ intercept[TestFailedException](23.1 !~== 23.52 relTol 0.02) + intercept[TestFailedException](23.1 !~== 22.74 relTol 0.02) + intercept[TestFailedException](23.1 ~== 23.63 relTol 0.02) + intercept[TestFailedException](23.1 ~== 22.34 relTol 0.02) + + assert(23.1 !~== 23.63 relTol 0.02) + assert(23.1 !~== 22.34 relTol 0.02) + assert(23.1 !~= 23.63 relTol 0.02) + assert(23.1 !~= 22.34 relTol 0.02) + assert(!(23.1 ~= 23.63 relTol 0.02)) + assert(!(23.1 ~= 22.34 relTol 0.02)) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. + intercept[TestFailedException](0.1 ~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 ~= 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~= 0.0 relTol 0.032) + intercept[TestFailedException](0.0 ~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 ~= 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~= 0.1 relTol 0.032) + + // Comparisons of numbers very close to zero. + assert(10 * Double.MinPositiveValue ~== 9.5 * Double.MinPositiveValue relTol 0.01) + assert(10 * Double.MinPositiveValue !~== 11 * Double.MinPositiveValue relTol 0.01) + + assert(-Double.MinPositiveValue ~== 1.18 * -Double.MinPositiveValue relTol 0.012) + assert(-Double.MinPositiveValue ~== 1.38 * -Double.MinPositiveValue relTol 0.012) + } + + test("Comparing doubles using absolute error.") { + + assert(17.8 ~== 17.99 absTol 0.2) + assert(17.8 ~== 17.61 absTol 0.2) + assert(17.8 ~= 17.99 absTol 0.2) + assert(17.8 ~= 17.61 absTol 0.2) + assert(!(17.8 !~= 17.99 absTol 0.2)) + assert(!(17.8 !~= 17.61 absTol 0.2)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](17.8 !~== 17.99 absTol 0.2) + intercept[TestFailedException](17.8 !~== 17.61 absTol 0.2) + intercept[TestFailedException](17.8 ~== 18.01 absTol 0.2) + intercept[TestFailedException](17.8 ~== 17.59 absTol 0.2) + + assert(17.8 !~== 18.01 absTol 0.2) + assert(17.8 !~== 17.59 absTol 0.2) + assert(17.8 !~= 18.01 absTol 0.2) + assert(17.8 !~= 17.59 absTol 0.2) + assert(!(17.8 ~= 18.01 absTol 0.2)) + assert(!(17.8 ~= 17.59 absTol 0.2)) + + // Comparisons of numbers very close to zero, and both side of zeros + assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + } + + test("Comparing vectors using relative error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)) + assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)) + + // Should throw exception with message when test fails. 
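Why 23.52 passes and 23.63 fails at relTol 0.02 in the double comparisons above: RelativeErrorComparison scales eps by the smaller of the two magnitudes, so the allowed gap here is 0.02 * 23.1 = 0.462. A standalone check of that rule (object name illustrative):

object RelTolCheck {
  def main(args: Array[String]): Unit = {
    // Same rule as RelativeErrorComparison: compare the difference against
    // eps times the smaller magnitude.
    def relClose(x: Double, y: Double, eps: Double): Boolean =
      math.abs(x - y) < eps * math.min(math.abs(x), math.abs(y))
    println(relClose(23.1, 23.52, 0.02))   // true:  |23.1 - 23.52| = 0.42  < 0.462
    println(relClose(23.1, 23.63, 0.02))   // false: |23.1 - 23.63| = 0.53 >= 0.462
  }
}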
+ intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.dense(Array(3.13, 0.0)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.sparse(2, Array(0), Array(3.13)) relTol 0.01) + + // Comparisons of two sparse vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== + Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) + + assert(Vectors.dense(Array(3.1, 3.5)) !~== + Vectors.sparse(2, Array(0, 1), Array(3.135, 3.534)) relTol 0.01) + } + + test("Comparing vectors using absolute error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)) + + // Should throw exception with message when test fails. 
+ intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + // Comparisons of two sparse vectors + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + // Comparisons of a dense vector and a sparse vector + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.dense(Array(3.1, 1E-3, 2.4)) absTol 1E-6) + } +} diff --git a/pom.xml b/pom.xml index d2e6b3c0ed5a4..ae97bf03c53a2 100644 --- a/pom.xml +++ b/pom.xml @@ -100,6 +100,7 @@ external/twitter external/kafka external/flume + external/flume-sink external/zeromq external/mqtt examples @@ -113,6 +114,7 @@ spark 2.10.4 2.10 + 2.0.1 0.18.1 shaded-protobuf org.spark-project.akka @@ -252,9 +254,15 @@ 3.3.2 - commons-codec - commons-codec - 1.5 + commons-codec + commons-codec + 1.5 + + + org.apache.commons + commons-math3 + 3.3 + test com.google.code.findbugs @@ -818,6 +826,15 @@ -target ${java.version} + + + + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} + + @@ -1139,5 +1156,15 @@ + + hive-thriftserver + + false + + + sql/hive-thriftserver + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5ff88f0dd1cac..537ca0dcf267d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -71,7 +71,12 @@ object MimaExcludes { "org.apache.spark.storage.TachyonStore.putValues") ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.flume.FlumeReceiver.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver.this") ) ++ Seq( // Ignore some private methods in ALS. 
ProblemFilters.exclude[MissingMethodProblem]( @@ -97,6 +102,14 @@ object MimaExcludes { "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) ++ + Seq ( // Package-private classes removed in SPARK-2341 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") ) case v if v.startsWith("1.0") => Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 62576f84dd031..a8bbd55861954 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,11 +30,12 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming, - streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql", - "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq").map(ProjectRef(buildLocation, _)) + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, + sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, + streamingTwitter, streamingZeromq) = + Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", + "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", + "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") @@ -100,7 +101,7 @@ object SparkBuild extends PomBuild { Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1))) - case _ => + case _ => } override val userPropertiesMap = System.getProperties.toMap @@ -156,10 +157,9 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */ // TODO: Add Sql to mima checks - allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)). 
- foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) + allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, + streamingFlumeSink).contains(x)).foreach(x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) @@ -167,12 +167,17 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst macro settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ enable(SQL.settings)(sql) /* Hive console settings */ enable(Hive.settings)(hive) + enable(Flume.settings)(streamingFlumeSink) + // TODO: move this to its upstream project. override def projectDefinitions(baseDirectory: File): Seq[Project] = { super.projectDefinitions(baseDirectory).map { x => @@ -183,10 +188,20 @@ object SparkBuild extends PomBuild { } -object SQL { +object Flume { + lazy val settings = sbtavro.SbtAvro.avroSettings +} +object Catalyst { lazy val settings = Seq( + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), + // Quasiquotes break compiling scala doc... + // TODO: Investigate fixing this. + sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) +} +object SQL { + lazy val settings = Seq( initialCommands in console := """ |import org.apache.spark.sql.catalyst.analysis._ @@ -201,7 +216,6 @@ object SQL { |import org.apache.spark.sql.test.TestSQLContext._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin ) - } object Hive { @@ -281,6 +295,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) .map(_.filterNot(_.getCanonicalPath.contains("network"))) + .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("executor"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) .map(_.filterNot(_.getCanonicalPath.contains("collection"))) @@ -301,7 +316,7 @@ object Unidoc { "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration", "mllib.tree.impurity", "mllib.tree.model", "mllib.util" ), - "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"), + "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" ) ) diff --git a/project/plugins.sbt b/project/plugins.sbt index d3ac4bf335e87..06d18e193076e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -24,3 +24,5 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.0") + +addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") diff --git a/python/lib/py4j-0.8.1-src.zip b/python/lib/py4j-0.8.1-src.zip deleted file mode 100644 index 2069a328d1f2e..0000000000000 Binary files a/python/lib/py4j-0.8.1-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.8.2.1-src.zip b/python/lib/py4j-0.8.2.1-src.zip new file mode 100644 index 0000000000000..5203b84d9119e Binary files /dev/null and b/python/lib/py4j-0.8.2.1-src.zip differ diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 312c75d112cbf..c58555fc9d2c5 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -49,6 +49,16 @@ Main entry point for accessing data stored in Apache Hive.. 
""" +# The following block allows us to import python's random instead of mllib.random for scripts in +# mllib that depend on top level pyspark packages, which transitively depend on python's random. +# Since Python's import logic looks for modules in the current package first, we eliminate +# mllib.random as a candidate for C{import random} by removing the first search path, the script's +# location, in order to force the loader to look in Python's top-level modules for C{random}. +import sys +s = sys.path.pop(0) +import random +sys.path.insert(0, s) + from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.sql import SQLContext diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 2204e9c9ca701..45d36e5d0e764 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -86,6 +86,7 @@ Exception:... """ +import select import struct import SocketServer import threading @@ -209,19 +210,38 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ + This handler will keep polling updates from the same socket until the + server is shutdown. + """ + def handle(self): from pyspark.accumulators import _accumulatorRegistry - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + +class AccumulatorServer(SocketServer.TCPServer): + """ + A simple TCP server that intercepts shutdown() in order to interrupt + our continuous polling on the handler. 
+ """ + server_shutdown = False + def shutdown(self): + self.server_shutdown = True + SocketServer.TCPServer.shutdown(self) def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 4fda2a9b950b8..68062483dedaa 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -560,8 +560,9 @@ class ItemGetterType(ctypes.Structure): ] - itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents - return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,)) + obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents + return self.save_reduce(operator.itemgetter, + obj.item if obj.nitems > 1 else (obj.item,)) if PyObject_HEAD: dispatch[operator.itemgetter] = save_itemgetter diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 830a6ee03f2a6..7b0f8d83aedc5 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -60,6 +60,7 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, @@ -378,7 +379,7 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None): + valueConverter=None, minSplits=None, batchSize=None): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -398,14 +399,18 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, @param valueConverter: @param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ minSplits = minSplits or min(self.defaultParallelism, 2) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, - keyConverter, valueConverter, minSplits) - return RDD(jrdd, self, PickleSerializer()) + keyConverter, valueConverter, minSplits, batchSize) + return RDD(jrdd, self, ser) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. 
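A hedged usage sketch for the new batchSize argument on sequenceFile, assuming a live sc and a purely hypothetical input path; batchSize=1 disables batching, while larger values pack that many Python objects into each Java object:

pairs = sc.sequenceFile("hdfs:///tmp/pairs.seq",            # hypothetical path
                        "org.apache.hadoop.io.IntWritable",
                        "org.apache.hadoop.io.Text",
                        batchSize=1)                         # no batching for this read
print(pairs.first())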
@@ -425,14 +430,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -449,14 +458,18 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -476,14 +489,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. 
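To show the conf-dict form of these readers, a small sketch mirroring the updated tests later in this patch; sc is assumed and the input path is hypothetical:

oldconf = {"mapred.input.dir": "/tmp/hello.txt"}             # hypothetical path
hello = sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
                     "org.apache.hadoop.io.LongWritable",
                     "org.apache.hadoop.io.Text",
                     conf=oldconf).collect()
print(hello)  # e.g. [(0, u'Hello World!')]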
@@ -500,11 +517,15 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, - keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 8e3ad6b783b6c..c6ca6a75df746 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -101,7 +101,7 @@ def _serialize_double(d): """ Serialize a double (float or numpy.float64) into a mutually understood format. """ - if type(d) == float or type(d) == float64: + if type(d) == float or type(d) == float64 or type(d) == int or type(d) == long: d = float64(d) ba = bytearray(8) _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64) @@ -176,6 +176,10 @@ def _deserialize_double(ba, offset=0): True >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0 True + >>> _deserialize_double(_serialize_double(1)) == 1.0 + True + >>> _deserialize_double(_serialize_double(1L)) == 1.0 + True >>> x = sys.float_info.max >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x True diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 9e28dfbb9145d..2bbb9c3fca315 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -66,7 +66,8 @@ def predict(self, x): if margin > 0: prob = 1 / (1 + exp(-margin)) else: - prob = 1 - 1 / (1 + exp(margin)) + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) return 1 if prob > 0.5 else 0 diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 71f4ad1a8d44e..54720c2324ca6 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -255,4 +255,8 @@ def _test(): exit(-1) if __name__ == "__main__": + # remove current path from list of search paths to avoid importing mllib.random + # for C{import random}, which is done in an external dependency of pyspark during doctests. + import sys + sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py new file mode 100644 index 0000000000000..36e710dbae7a8 --- /dev/null +++ b/python/pyspark/mllib/random.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
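The LogisticRegressionModel.predict change earlier in this hunk replaces 1 - 1/(1 + exp(margin)) with exp(margin)/(1 + exp(margin)); a standalone sketch of the numerical difference, with an illustrative margin:

from math import exp

def prob_old(margin):
    return 1 - 1 / (1 + exp(margin))       # cancels to 0.0 once exp(margin) drops below half an ulp of 1.0

def prob_new(margin):
    exp_margin = exp(margin)
    return exp_margin / (1 + exp_margin)   # keeps the tiny probability representable

print(prob_old(-40.0))   # 0.0
print(prob_new(-40.0))   # ~4.25e-18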
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for random data generation. +""" + + +from pyspark.rdd import RDD +from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector +from pyspark.serializers import NoOpSerializer + +class RandomRDDGenerators: + """ + Generator methods for creating RDDs comprised of i.i.d samples from + some distribution. + """ + + @staticmethod + def uniformRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the + uniform distribution on [0.0, 1.0]. + + To transform the distribution in the generated RDD from U[0.0, 1.0] + to U[a, b], use + C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + .map(lambda v: a + (b - a) * v)} + + >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> len(x) + 100 + >>> max(x) <= 1.0 and min(x) >= 0.0 + True + >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + 4 + >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts == sc.defaultParallelism + True + """ + jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def normalRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the standard normal + distribution. + + To transform the distribution in the generated RDD from standard normal + to some other normal N(mean, sigma), use + C{RandomRDDGenerators.normal(sc, n, p, seed)\ + .map(lambda v: mean + sigma * v)} + + >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - 0.0) < 0.1 + True + >>> abs(stats.stdev() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def poissonRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the Poisson + distribution with the input mean. + + >>> mean = 100.0 + >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the uniform distribution on [0.0 1.0]. 
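A usage sketch for the scalar generators above, assuming a live sc; the bounds and moments are arbitrary illustrations of the shift-and-scale recipe quoted in the docstrings:

from pyspark.mllib.random import RandomRDDGenerators

a, b = -1.0, 1.0
uniform_ab = RandomRDDGenerators.uniformRDD(sc, 1000, seed=1L) \
    .map(lambda v: a + (b - a) * v)         # U[a, b]
normal_5_2 = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) \
    .map(lambda v: 5.0 + 2.0 * v)           # N(5, 2)
print(uniform_ab.stats())
print(normal_5_2.stats())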
+ + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat.shape + (10, 10) + >>> mat.max() <= 1.0 and mat.min() >= 0.0 + True + >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + 4 + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the standard normal distribution. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - 0.0) < 0.1 + True + >>> abs(mat.std() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the Poisson distribution with the input mean. + + >>> import numpy as np + >>> mean = 100.0 + >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py new file mode 100644 index 0000000000000..0a08a562d1f1f --- /dev/null +++ b/python/pyspark/mllib/stat.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for statistical functions in MLlib. 
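Similarly for the vector generators, a brief sketch (assuming sc and NumPy on the driver) that shifts standard-normal rows to an illustrative mean of 3.0:

import numpy as np
from pyspark.mllib.random import RandomRDDGenerators

rows = RandomRDDGenerators.normalVectorRDD(sc, 100, 10, seed=1L) \
    .map(lambda v: 3.0 + v)                  # each entry ~ N(3, 1)
mat = np.matrix(rows.collect())
print(mat.shape)                             # (100, 10)
print(abs(mat.mean() - 3.0) < 0.1)           # True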
+""" + +from pyspark.mllib._common import \ + _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ + _serialize_double, _serialize_double_vector, \ + _deserialize_double, _deserialize_double_matrix + +class Statistics(object): + + @staticmethod + def corr(x, y=None, method=None): + """ + Compute the correlation (matrix) for the input RDD(s) using the + specified method. + Methods currently supported: I{pearson (default), spearman}. + + If a single RDD of Vectors is passed in, a correlation matrix + comparing the columns in the input RDD is returned. Use C{method=} + to specify the method to be used for single RDD inout. + If two RDDs of floats are passed in, a single float is returned. + + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) + >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) + >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) + >>> abs(Statistics.corr(x, y) - 0.6546537) < 1e-7 + True + >>> Statistics.corr(x, y) == Statistics.corr(x, y, "pearson") + True + >>> Statistics.corr(x, y, "spearman") + 0.5 + >>> from math import isnan + >>> isnan(Statistics.corr(x, zeros)) + True + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), + ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) + >>> Statistics.corr(rdd) + array([[ 1. , 0.05564149, nan, 0.40047142], + [ 0.05564149, 1. , nan, 0.91359586], + [ nan, nan, 1. , nan], + [ 0.40047142, 0.91359586, nan, 1. ]]) + >>> Statistics.corr(rdd, method="spearman") + array([[ 1. , 0.10540926, nan, 0.4 ], + [ 0.10540926, 1. , nan, 0.9486833 ], + [ nan, nan, 1. , nan], + [ 0.4 , 0.9486833 , nan, 1. ]]) + >>> try: + ... Statistics.corr(rdd, "spearman") + ... print "Method name as second argument without 'method=' shouldn't be allowed." + ... except TypeError: + ... pass + """ + sc = x.ctx + # Check inputs to determine whether a single value or a matrix is needed for output. + # Since it's legal for users to use the method name as the second argument, we need to + # check if y is used to specify the method name instead. + if type(y) == str: + raise TypeError("Use 'method=' to specify method name.") + if not y: + try: + Xser = _get_unmangled_double_vector_rdd(x) + except TypeError: + raise TypeError("corr called on a single RDD not consisted of Vectors.") + resultMat = sc._jvm.PythonMLLibAPI().corr(Xser._jrdd, method) + return _deserialize_double_matrix(resultMat) + else: + xSer = _get_unmangled_rdd(x, _serialize_double) + ySer = _get_unmangled_rdd(y, _serialize_double) + result = sc._jvm.PythonMLLibAPI().corr(xSer._jrdd, ySer._jrdd, method) + return result + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index a707a9dcd5b49..d94900cefdb77 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -29,15 +29,18 @@ class MLUtils: Helper methods to load, save and pre-process data used in MLlib. """ + @deprecated @staticmethod def _parse_libsvm_line(line, multiclass): + return _parse_libsvm_line(line) + + @staticmethod + def _parse_libsvm_line(line): """ Parses a line in LIBSVM format into (label, indices, values). 
""" items = line.split(None) label = float(items[0]) - if not multiclass: - label = 1.0 if label > 0.5 else 0.0 nnz = len(items) - 1 indices = np.zeros(nnz, dtype=np.int32) values = np.zeros(nnz) @@ -64,8 +67,13 @@ def _convert_labeled_point_to_libsvm(p): " but got " % type(v)) return " ".join(items) + @deprecated @staticmethod def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + return loadLibSVMFile(sc, path, numFeatures, minPartitions) + + @staticmethod + def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): """ Loads labeled data in the LIBSVM format into an RDD of LabeledPoint. The LIBSVM format is a text-based format used by @@ -81,13 +89,6 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non @param sc: Spark context @param path: file or directory path in any Hadoop-supported file system URI - @param multiclass: whether the input labels contain more than - two classes. If false, any label with value - greater than 0.5 will be mapped to 1.0, or - 0.0 otherwise. So it works for both +1/-1 and - 1/0 cases. If true, the double value parsed - directly from the label string will be used - as the label value. @param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. This is useful when the dataset is @@ -105,7 +106,7 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() - >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect() + >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() >>> type(examples[0]) == LabeledPoint True @@ -124,7 +125,7 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non """ lines = sc.textFile(path, minPartitions) - parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l, multiclass)) + parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) if numFeatures <= 0: parsed.cache() numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b84d976114f0d..e8fcc900efb24 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -231,6 +231,13 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() + def _toPickleSerialization(self): + if (self._jrdd_deserializer == PickleSerializer() or + self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): + return self + else: + return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def id(self): """ A unique ID for this RDD (within its SparkContext). @@ -1030,6 +1037,113 @@ def first(self): """ return self.take(1)[0] + def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the new Hadoop OutputFormat API (mapreduce package). Keys/values are + converted for output using either user specified converters or, by default, + L{org.apache.spark.api.python.JavaToWritableConverter}. 
+ + @param conf: Hadoop job configuration, passed in as a dict + @param keyConverter: (None by default) + @param valueConverter: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + keyConverter, valueConverter, True) + + def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, + keyConverter=None, valueConverter=None, conf=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the new Hadoop OutputFormat API (mapreduce package). Key and value types + will be inferred if not specified. Keys and values are converted for output using either + user specified converters or L{org.apache.spark.api.python.JavaToWritableConverter}. The + C{conf} is applied on top of the base Hadoop conf associated with the SparkContext + of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. + + @param path: path to Hadoop file + @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.IntWritable", None by default) + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.Text", None by default) + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: Hadoop job configuration, passed in as a dict (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) + + def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the old Hadoop OutputFormat API (mapred package). Keys/values are + converted for output using either user specified converters or, by default, + L{org.apache.spark.api.python.JavaToWritableConverter}. + + @param conf: Hadoop job configuration, passed in as a dict + @param keyConverter: (None by default) + @param valueConverter: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + keyConverter, valueConverter, False) + + def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, + keyConverter=None, valueConverter=None, conf=None, + compressionCodecClass=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the old Hadoop OutputFormat API (mapred package). Key and value types + will be inferred if not specified. Keys and values are converted for output using either + user specified converters or L{org.apache.spark.api.python.JavaToWritableConverter}. 
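A save-side sketch for the writer APIs above, in the same spirit as the tests added below; sc is assumed and the output path is hypothetical:

dict_data = [(1, {}), (1, {"row1": 1.0}), (2, {"row2": 2.0})]
sc.parallelize(dict_data).saveAsHadoopFile(
    "/tmp/oldhadoop-out",                                    # hypothetical path
    "org.apache.hadoop.mapred.SequenceFileOutputFormat",
    "org.apache.hadoop.io.IntWritable",
    "org.apache.hadoop.io.MapWritable")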
The + C{conf} is applied on top of the base Hadoop conf associated with the SparkContext + of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. + + @param path: path to Hadoop file + @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.IntWritable", None by default) + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.Text", None by default) + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: (None by default) + @param compressionCodecClass: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, + jconf, compressionCodecClass) + + def saveAsSequenceFile(self, path, compressionCodecClass=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the L{org.apache.hadoop.io.Writable} types that we convert from the + RDD's key and value types. The mechanism is as follows: + 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. + 2. Keys and values of this Java RDD are converted to Writables and written out. + + @param path: path to sequence file + @param compressionCodecClass: (None by default) + """ + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + path, compressionCodecClass) + def saveAsPickleFile(self, path, batchSize=10): """ Save this RDD as a SequenceFile of serialized objects. The serializer diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index cb83e89176823..9388ead5eaad3 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,12 +15,458 @@ # limitations under the License. # +import warnings + from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer from py4j.protocol import Py4JError -__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] +__all__ = [ + "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", + "ShortType", "ArrayType", "MapType", "StructField", "StructType", + "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] + + +class PrimitiveTypeSingleton(type): + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class StringType(object): + """Spark SQL StringType + + The data type representing string values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "StringType" + + +class BinaryType(object): + """Spark SQL BinaryType + + The data type representing bytearray values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "BinaryType" + + +class BooleanType(object): + """Spark SQL BooleanType + + The data type representing bool values. 
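Because each primitive type class above uses PrimitiveTypeSingleton as its metaclass, repeated construction hands back one shared instance; a quick sketch:

from pyspark.sql import StringType, IntegerType

print(StringType() is StringType())     # True: one cached instance per type class
print(StringType() == IntegerType())    # False: different singletons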
+ + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "BooleanType" + + +class TimestampType(object): + """Spark SQL TimestampType + + The data type representing datetime.datetime values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "TimestampType" + + +class DecimalType(object): + """Spark SQL DecimalType + + The data type representing decimal.Decimal values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "DecimalType" + + +class DoubleType(object): + """Spark SQL DoubleType + + The data type representing float values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "DoubleType" + + +class FloatType(object): + """Spark SQL FloatType + + The data type representing single precision floating-point values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "FloatType" + + +class ByteType(object): + """Spark SQL ByteType + + The data type representing int values with 1 singed byte. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "ByteType" + + +class IntegerType(object): + """Spark SQL IntegerType + + The data type representing int values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "IntegerType" + + +class LongType(object): + """Spark SQL LongType + + The data type representing long values. If the any value is beyond the range of + [-9223372036854775808, 9223372036854775807], please use DecimalType. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "LongType" + + +class ShortType(object): + """Spark SQL ShortType + + The data type representing int values with 2 signed bytes. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "ShortType" + + +class ArrayType(object): + """Spark SQL ArrayType + + The data type representing list values. + An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool). + The field of elementType is used to specify the type of array elements. + The field of containsNull is used to specify if the array has None values. + + """ + def __init__(self, elementType, containsNull=False): + """Creates an ArrayType + + :param elementType: the data type of elements. + :param containsNull: indicates whether the list contains None values. + + >>> ArrayType(StringType) == ArrayType(StringType, False) + True + >>> ArrayType(StringType, True) == ArrayType(StringType) + False + """ + self.elementType = elementType + self.containsNull = containsNull + + def __repr__(self): + return "ArrayType(" + self.elementType.__repr__() + "," + \ + str(self.containsNull).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.elementType == other.elementType and + self.containsNull == other.containsNull) + + def __ne__(self, other): + return not self.__eq__(other) + + +class MapType(object): + """Spark SQL MapType + + The data type representing dict values. + A MapType object comprises three fields, + keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool). + The field of keyType is used to specify the type of keys in the map. + The field of valueType is used to specify the type of values in the map. + The field of valueContainsNull is used to specify if values of this map has None values. + For values of a MapType column, keys are not allowed to have None values. 
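A brief construction sketch for the container types above; the element and value types are arbitrary examples:

from pyspark.sql import ArrayType, MapType, StringType, IntegerType

arr = ArrayType(IntegerType(), containsNull=True)
m = MapType(StringType(), arr)                          # valueContainsNull defaults to True
print(arr)                                              # ArrayType(IntegerType,true)
print(m == MapType(StringType(), arr, False))           # False: valueContainsNull differs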
+ + """ + def __init__(self, keyType, valueType, valueContainsNull=True): + """Creates a MapType + :param keyType: the data type of keys. + :param valueType: the data type of values. + :param valueContainsNull: indicates whether values contains null values. + + >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True) + True + >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType) + False + """ + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def __repr__(self): + return "MapType(" + self.keyType.__repr__() + "," + \ + self.valueType.__repr__() + "," + \ + str(self.valueContainsNull).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.keyType == other.keyType and + self.valueType == other.valueType and + self.valueContainsNull == other.valueContainsNull) + + def __ne__(self, other): + return not self.__eq__(other) + + +class StructField(object): + """Spark SQL StructField + + Represents a field in a StructType. + A StructField object comprises three fields, name (a string), dataType (a DataType), + and nullable (a bool). The field of name is the name of a StructField. The field of + dataType specifies the data type of a StructField. + The field of nullable specifies if values of a StructField can contain None values. + + """ + def __init__(self, name, dataType, nullable): + """Creates a StructField + :param name: the name of this field. + :param dataType: the data type of this field. + :param nullable: indicates whether values of this field can be null. + + >>> StructField("f1", StringType, True) == StructField("f1", StringType, True) + True + >>> StructField("f1", StringType, True) == StructField("f2", StringType, True) + False + """ + self.name = name + self.dataType = dataType + self.nullable = nullable + + def __repr__(self): + return "StructField(" + self.name + "," + \ + self.dataType.__repr__() + "," + \ + str(self.nullable).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.name == other.name and + self.dataType == other.dataType and + self.nullable == other.nullable) + + def __ne__(self, other): + return not self.__eq__(other) + + +class StructType(object): + """Spark SQL StructType + + The data type representing namedtuple values. + A StructType object comprises a list of L{StructField}s. + + """ + def __init__(self, fields): + """Creates a StructType + + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True), + ... 
[StructField("f2", IntegerType, False)]]) + >>> struct1 == struct2 + False + """ + self.fields = fields + + def __repr__(self): + return "StructType(List(" + \ + ",".join([field.__repr__() for field in self.fields]) + "))" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.fields == other.fields) + + def __ne__(self, other): + return not self.__eq__(other) + + +def _parse_datatype_list(datatype_list_string): + """Parses a list of comma separated data types.""" + index = 0 + datatype_list = [] + start = 0 + depth = 0 + while index < len(datatype_list_string): + if depth == 0 and datatype_list_string[index] == ",": + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + start = index + 1 + elif datatype_list_string[index] == "(": + depth += 1 + elif datatype_list_string[index] == ")": + depth -= 1 + + index += 1 + + # Handle the last data type + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + return datatype_list + + +def _parse_datatype_string(datatype_string): + """Parses the given data type string. + + >>> def check_datatype(datatype): + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__()) + ... python_datatype = _parse_datatype_string(scala_datatype.toString()) + ... return datatype == python_datatype + >>> check_datatype(StringType()) + True + >>> check_datatype(BinaryType()) + True + >>> check_datatype(BooleanType()) + True + >>> check_datatype(TimestampType()) + True + >>> check_datatype(DecimalType()) + True + >>> check_datatype(DoubleType()) + True + >>> check_datatype(FloatType()) + True + >>> check_datatype(ByteType()) + True + >>> check_datatype(IntegerType()) + True + >>> check_datatype(LongType()) + True + >>> check_datatype(ShortType()) + True + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + True + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + True + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + True + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False)]) + >>> check_datatype(complex_structtype) + True + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + True + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False) + >>> check_datatype(complex_maptype) + True + """ + left_bracket_index = datatype_string.find("(") + if left_bracket_index == -1: + # It is a primitive type. 
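The doctest above drives the round trip through the Scala side; as a smaller, purely Python sketch (field names are arbitrary), a nested schema's __repr__ parses back into an equal object:

from pyspark.sql import (StructType, StructField, ArrayType,
                         IntegerType, StringType, _parse_datatype_string)

schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("tags", ArrayType(StringType(), True), True)])
print(_parse_datatype_string(schema.__repr__()) == schema)   # True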
+ left_bracket_index = len(datatype_string) + type_or_field = datatype_string[:left_bracket_index] + rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip() + if type_or_field == "StringType": + return StringType() + elif type_or_field == "BinaryType": + return BinaryType() + elif type_or_field == "BooleanType": + return BooleanType() + elif type_or_field == "TimestampType": + return TimestampType() + elif type_or_field == "DecimalType": + return DecimalType() + elif type_or_field == "DoubleType": + return DoubleType() + elif type_or_field == "FloatType": + return FloatType() + elif type_or_field == "ByteType": + return ByteType() + elif type_or_field == "IntegerType": + return IntegerType() + elif type_or_field == "LongType": + return LongType() + elif type_or_field == "ShortType": + return ShortType() + elif type_or_field == "ArrayType": + last_comma_index = rest_part.rfind(",") + containsNull = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + containsNull = False + elementType = _parse_datatype_string(rest_part[:last_comma_index].strip()) + return ArrayType(elementType, containsNull) + elif type_or_field == "MapType": + last_comma_index = rest_part.rfind(",") + valueContainsNull = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + valueContainsNull = False + keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip()) + return MapType(keyType, valueType, valueContainsNull) + elif type_or_field == "StructField": + first_comma_index = rest_part.find(",") + name = rest_part[:first_comma_index].strip() + last_comma_index = rest_part.rfind(",") + nullable = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + nullable = False + dataType = _parse_datatype_string( + rest_part[first_comma_index+1:last_comma_index].strip()) + return StructField(name, dataType, nullable) + elif type_or_field == "StructType": + # rest_part should be in the format like + # List(StructField(field1,IntegerType,false)). + field_list_string = rest_part[rest_part.find("(")+1:-1] + fields = _parse_datatype_list(field_list_string) + return StructType(fields) class SQLContext: @@ -47,12 +493,14 @@ def __init__(self, sparkContext, sqlContext=None): ... ValueError:... - >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L, - ... "boolean" : True}]) + >>> from datetime import datetime + >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L, + ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1}, + ... "list": [1, 2, 3]}]) >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long, - ... x.boolean)) + ... x.boolean, x.time, x.dict["a"], x.list)) >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True) + (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3]) """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -88,13 +536,13 @@ def inferSchema(self, rdd): >>> from array import array >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, - ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] + >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}}, + ... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}] True >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, - ... 
{"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] + >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]}, + ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}] True """ if (rdd.__class__ is SchemaRDD): @@ -107,6 +555,40 @@ def inferSchema(self, rdd): srdd = self._ssql_ctx.inferSchema(jrdd.rdd()) return SchemaRDD(srdd, self) + def applySchema(self, rdd, schema): + """Applies the given schema to the given RDD of L{dict}s. + + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... StructField("field2", StringType(), False)]) + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT * from table1") + >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, + ... {"field1" : 3, "field2": "row3"}] + True + >>> from datetime import datetime + >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0, + ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2}, + ... "list": [1, 2, 3]}]) + >>> schema = StructType([ + ... StructField("byte", ByteType(), False), + ... StructField("short", ShortType(), False), + ... StructField("float", FloatType(), False), + ... StructField("time", TimestampType(), False), + ... StructField("map", MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False), + ... StructField("list", ArrayType(ByteType(), False), False), + ... StructField("null", DoubleType(), True)]) + >>> srdd = sqlCtx.applySchema(rdd, schema).map( + ... lambda x: ( + ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null)) + >>> srdd.collect()[0] + (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + """ + jrdd = self._pythonToJavaMap(rdd._jrdd) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__()) + return SchemaRDD(srdd, self) + def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -137,10 +619,11 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path): - """Loads a text file storing one JSON object per line, - returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. + def jsonFile(self, path, schema=None): + """Loads a text file storing one JSON object per line as a L{SchemaRDD}. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it goes through the entire dataset once to determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -149,8 +632,8 @@ def jsonFile(self, path): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() == [ @@ -158,16 +641,45 @@ def jsonFile(self, path): ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] True + >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... 
"SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + >>> srdd4.collect() == [ + ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, + ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, + ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] + True + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + >>> srdd6.collect() == [ + ... {"f1": "row1", "f2": None, "f3": None}, + ... {"f1": None, "f2": [10, 11], "f3": 10}, + ... {"f1": "row3", "f2": [], "f3": None}] + True """ - jschema_rdd = self._ssql_ctx.jsonFile(path) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonFile(path) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(jschema_rdd, self) - def jsonRDD(self, rdd): - """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. + def jsonRDD(self, rdd, schema=None): + """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - >>> srdd = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it goes through the entire dataset once to determine the schema. + + >>> srdd1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() == [ @@ -175,6 +687,29 @@ def jsonRDD(self, rdd): ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] True + >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + >>> srdd4.collect() == [ + ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, + ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, + ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] + True + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + >>> srdd6.collect() == [ + ... {"f1": "row1", "f2": None, "f3": None}, + ... {"f1": None, "f2": [10, 11], "f3": 10}, + ... 
{"f1": "row3", "f2": [], "f3": None}] + True """ def func(split, iterator): for x in iterator: @@ -184,7 +719,11 @@ def func(split, iterator): keyed = PipelinedRDD(rdd, func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) def sql(self, sqlQuery): @@ -276,6 +815,10 @@ class LocalHiveContext(HiveContext): 130091 """ + def __init__(self, sparkContext, sqlContext=None): + HiveContext.__init__(self, sparkContext, sqlContext) + warnings.warn("LocalHiveContext is deprecated. Use HiveContext instead.", DeprecationWarning) + def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) @@ -387,6 +930,10 @@ def saveAsTable(self, tableName): """Creates a new table with the contents of this SchemaRDD.""" self._jschema_rdd.saveAsTable(tableName) + def schema(self): + """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" + return _parse_datatype_string(self._jschema_rdd.schema().toString()) + def schemaString(self): """Returns the output schema in the tree format.""" return self._jschema_rdd.schemaString() @@ -509,8 +1056,8 @@ def _test(): {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) globs['nestedRdd2'] = sc.parallelize([ - {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)}, - {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}]) + {"f1": [[1, 2], [2, 3]], "f2": [1, 2]}, + {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}]) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 63cc5e9ad96fa..c29deb9574ea2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,6 +19,7 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. 
""" +from array import array from fileinput import input from glob import glob import os @@ -165,11 +166,17 @@ class TestAddFile(PySparkTestCase): def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that # this job fails due to `userlibrary` not being on the Python path: + # disable logging in log4j temporarily + log4j = self.sc._jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) def func(x): from userlibrary import UserClass return UserClass().hello() self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) + log4j.LogManager.getRootLogger().setLevel(old_level) + # Add the file, so the job should now succeed: path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") self.sc.addPyFile(path) @@ -278,6 +285,12 @@ def combOp(x, y): self.assertEqual(set([2]), sets[3]) self.assertEqual(set([1, 3]), sets[5]) + def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + class TestIO(PySparkTestCase): @@ -315,6 +328,17 @@ def test_sequencefiles(self): ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] self.assertEqual(doubles, ed) + bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) + ebs = [(1, bytearray('aa', 'utf-8')), + (1, bytearray('aa', 'utf-8')), + (2, bytearray('aa', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (3, bytearray('cc', 'utf-8'))] + self.assertEqual(bytes, ebs) + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", "org.apache.hadoop.io.Text", "org.apache.hadoop.io.Text").collect()) @@ -341,14 +365,34 @@ def test_sequencefiles(self): maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable").collect()) - em = [(1, {2.0: u'aa'}), + em = [(1, {}), (1, {3.0: u'bb'}), (2, {1.0: u'aa'}), (2, {1.0: u'cc'}), - (2, {3.0: u'bb'}), (3, {2.0: u'dd'})] self.assertEqual(maps, em) + # arrays get pickled to tuples by default + tuples = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable").collect()) + et = [(1, ()), + (2, (3.0, 4.0, 5.0)), + (3, (4.0, 5.0, 6.0))] + self.assertEqual(tuples, et) + + # with custom converters, primitive arrays can stay as arrays + arrays = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + ea = [(1, array('d')), + (2, array('d', [3.0, 4.0, 5.0])), + (3, array('d', [4.0, 5.0, 6.0]))] + self.assertEqual(arrays, ea) + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) @@ -357,6 +401,12 @@ def test_sequencefiles(self): u'double': 54.0, u'int': 123, u'str': u'test1'}) self.assertEqual(clazz[0], ec) + unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + 
batchSize=1).collect()) + self.assertEqual(unbatched_clazz[0], ec) + def test_oldhadoop(self): basepath = self.tempdir.name ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", @@ -367,10 +417,11 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - hello = self.sc.hadoopFile(hellopath, - "org.apache.hadoop.mapred.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text").collect() + oldconf = {"mapred.input.dir" : hellopath} + hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=oldconf).collect() result = [(0, u'Hello World!')] self.assertEqual(hello, result) @@ -385,10 +436,11 @@ def test_newhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - hello = self.sc.newAPIHadoopFile(hellopath, - "org.apache.hadoop.mapreduce.lib.input.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text").collect() + newconf = {"mapred.input.dir" : hellopath} + hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=newconf).collect() result = [(0, u'Hello World!')] self.assertEqual(hello, result) @@ -423,16 +475,267 @@ def test_bad_inputs(self): "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text")) - def test_converter(self): + def test_converters(self): + # use of custom converters basepath = self.tempdir.name maps = sorted(self.sc.sequenceFile( basepath + "/sftestdata/sfmap/", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable", - valueConverter="org.apache.spark.api.python.TestConverter").collect()) - em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])] + keyConverter="org.apache.spark.api.python.TestInputKeyConverter", + valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) + em = [(u'\x01', []), + (u'\x01', [3.0]), + (u'\x02', [1.0]), + (u'\x02', [1.0]), + (u'\x03', [2.0])] + self.assertEqual(maps, em) + +class TestOutputFormat(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + + def tearDown(self): + PySparkTestCase.tearDown(self) + shutil.rmtree(self.tempdir.name, ignore_errors=True) + + def test_sequencefiles(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") + ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) + self.assertEqual(ints, ei) + + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") + doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) + self.assertEqual(doubles, ed) + + ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] + self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") + bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) + self.assertEqual(bytes, ebs) + + et = [(u'1', u'aa'), + (u'2', u'bb'), + (u'3', u'cc')] + self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") + text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) + 
self.assertEqual(text, et) + + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") + bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) + self.assertEqual(bools, eb) + + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") + nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) + self.assertEqual(nulls, en) + + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") + maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect()) self.assertEqual(maps, em) + def test_oldhadoop(self): + basepath = self.tempdir.name + dict_data = [(1, {}), + (1, {"row1" : 1.0}), + (2, {"row2" : 2.0})] + self.sc.parallelize(dict_data).saveAsHadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable") + result = sorted(self.sc.hadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect()) + self.assertEqual(result, dict_data) + + conf = { + "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.MapWritable", + "mapred.output.dir" : basepath + "/olddataset/"} + self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) + input_conf = {"mapred.input.dir" : basepath + "/olddataset/"} + old_dataset = sorted(self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + conf=input_conf).collect()) + self.assertEqual(old_dataset, dict_data) + + def test_newhadoop(self): + basepath = self.tempdir.name + # use custom ArrayWritable types and converters to handle arrays + array_data = [(1, array('d')), + (1, array('d', [1.0, 2.0, 3.0])), + (2, array('d', [3.0, 4.0, 5.0]))] + self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + self.assertEqual(result, array_data) + + conf = {"mapreduce.outputformat.class" : + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.spark.api.python.DoubleArrayWritable", + "mapred.output.dir" : basepath + "/newdataset/"} + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(conf, + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + input_conf = {"mapred.input.dir" : basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + 
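+            # Read the dataset back through the new ("mapreduce") API; the custom
+            # valueConverter turns each DoubleArrayWritable back into array('d').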
"org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", + conf=input_conf).collect()) + self.assertEqual(new_dataset, array_data) + + def test_newolderror(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/newolderror/saveAsHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/newolderror/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/badinputs/saveAsHadoopFile/", + "org.apache.hadoop.mapred.NotValidOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/badinputs/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + data = [(1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/converters/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", + valueConverter="org.apache.spark.api.python.TestOutputValueConverter") + converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) + expected = [(u'1', 3.0), + (u'2', 1.0), + (u'3', 2.0)] + self.assertEqual(converted, expected) + + def test_reserialization(self): + basepath = self.tempdir.name + x = range(1, 5) + y = range(1001, 1005) + data = zip(x, y) + rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) + rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") + result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) + self.assertEqual(result1, data) + + rdd.saveAsHadoopFile(basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") + result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) + self.assertEqual(result2, data) + + rdd.saveAsNewAPIHadoopFile(basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) + self.assertEqual(result3, data) + + conf4 = { + "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.dir" : basepath + "/reserialize/dataset"} + rdd.saveAsHadoopDataset(conf4) + result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) + self.assertEqual(result4, data) + + conf5 = {"mapreduce.outputformat.class" : + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.dir" : basepath + "/reserialize/newdataset"} + 
rdd.saveAsNewAPIHadoopDataset(conf5) + result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) + self.assertEqual(result5, data) + + def test_unbatched_save_and_read(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei, numSlices=len(ei)).saveAsSequenceFile( + basepath + "/unbatched/") + + unbatched_sequence = sorted(self.sc.sequenceFile(basepath + "/unbatched/", + batchSize=1).collect()) + self.assertEqual(unbatched_sequence, ei) + + unbatched_hadoopFile = sorted(self.sc.hadoopFile(basepath + "/unbatched/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + batchSize=1).collect()) + self.assertEqual(unbatched_hadoopFile, ei) + + unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile(basepath + "/unbatched/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + batchSize=1).collect()) + self.assertEqual(unbatched_newAPIHadoopFile, ei) + + oldconf = {"mapred.input.dir" : basepath + "/unbatched/"} + unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=oldconf, + batchSize=1).collect()) + self.assertEqual(unbatched_hadoopRDD, ei) + + newconf = {"mapred.input.dir" : basepath + "/unbatched/"} + unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=newconf, + batchSize=1).collect()) + self.assertEqual(unbatched_newAPIHadoopRDD, ei) + + def test_malformed_RDD(self): + basepath = self.tempdir.name + # non-batch-serialized RDD[[(K, V)]] should be rejected + data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] + rdd = self.sc.parallelize(data, numSlices=len(data)) + self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( + basepath + "/malformed/sequence")) class TestDaemon(unittest.TestCase): def connect(self, port): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 24d41b12d1b1a..2770f63059853 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -75,14 +75,19 @@ def main(infile, outfile): init_time = time.time() iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - except Exception as e: - # Write the error to stderr in addition to trying to pass it back to - # Java, in case it happened while serializing a record - print >> sys.stderr, "PySpark worker failed with exception:" - print >> sys.stderr, traceback.format_exc() - write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc(), outfile) - sys.exit(-1) + except Exception: + try: + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) + write_with_length(traceback.format_exc(), outfile) + outfile.flush() + except IOError: + # JVM close the socket + pass + except Exception: + # Write the error to stderr if it happened while serializing + print >> sys.stderr, "PySpark worker failed with exception:" + print >> sys.stderr, traceback.format_exc() + exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output diff --git a/python/run-tests b/python/run-tests index 
29f755fc0dcd3..5049e15ce5f8a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -67,6 +67,7 @@ run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/linalg.py" +run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" diff --git a/repl/pom.xml b/repl/pom.xml index 4ebb1b82f0e8c..68f4504450778 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -55,6 +55,12 @@ ${project.version} runtime + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + org.eclipse.jetty jetty-server diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e1db4d5395ab9..42c7e511dc3f5 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -230,6 +230,20 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, case xs => xs find (_.name == cmd) } } + private var fallbackMode = false + + private def toggleFallbackMode() { + val old = fallbackMode + fallbackMode = !old + System.setProperty("spark.repl.fallback", fallbackMode.toString) + echo(s""" + |Switched ${if (old) "off" else "on"} fallback mode without restarting. + | If you have defined classes in the repl, it would + |be good to redefine them in case you plan to use them. If you still run + |into issues it would be good to restart the repl and turn on `:fallback` + |mode as first command. + """.stripMargin) + } /** Show the history */ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { @@ -299,6 +313,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), + nullary("fallback", """ + |disable/enable advanced repl changes, these fix some issues but may introduce others. + |This mode will be removed once these fixes stabilize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) ) @@ -557,29 +574,27 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, if (isReplPower) powerCommands else Nil )*/ - val replayQuestionMessage = + private val replayQuestionMessage = """|That entry seems to have slain the compiler. Shall I replay |your session? I can re-run each line except the last one.
|[y/n] """.trim.stripMargin - private val crashRecovery: PartialFunction[Throwable, Boolean] = { - case ex: Throwable => - echo(intp.global.throwableAsString(ex)) - - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true + private def crashRecovery(ex: Throwable): Boolean = { + echo(ex.toString) + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("\nUnrecoverable error.") + throw ex + case _ => + def fn(): Boolean = + try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + true } /** The main read-eval-print loop for the repl. It calls @@ -605,7 +620,10 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } def innerLoop() { - if ( try processLine(readOneLine()) catch crashRecovery ) + val shouldContinue = try { + processLine(readOneLine()) + } catch {case t: Throwable => crashRecovery(t)} + if (shouldContinue) innerLoop() } innerLoop() diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 3842c291d0b7b..f60bbb4662af1 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -892,11 +892,16 @@ import org.apache.spark.util.Utils def definedTypeSymbol(name: String) = definedSymbols(newTypeName(name)) def definedTermSymbol(name: String) = definedSymbols(newTermName(name)) + val definedClasses = handlers.exists { + case _: ClassHandler => true + case _ => false + } + /** Code to import bound names from previous lines - accessPath is code to * append to objectName to access anything bound by request. */ val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = - importsCode(referencedNames.toSet) + importsCode(referencedNames.toSet, definedClasses) /** Code to access a variable with the specified name */ def fullPath(vname: String) = { diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala index 9099e052f5796..193a42dcded12 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -108,8 +108,9 @@ trait SparkImports { * last one imported is actually usable. */ case class SparkComputedImports(prepend: String, append: String, access: String) + def fallback = System.getProperty("spark.repl.fallback", "false").toBoolean - protected def importsCode(wanted: Set[Name]): SparkComputedImports = { + protected def importsCode(wanted: Set[Name], definedClass: Boolean): SparkComputedImports = { /** Narrow down the list of requests from which imports * should be taken. Removes requests which cannot contribute * useful imports for the specified set of wanted names. @@ -124,8 +125,14 @@ trait SparkImports { // Single symbol imports might be implicits! See bug #1752. Rather than // try to finesse this, we will mimic all imports for now. 
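+      // With this change, when the current line defines a class and spark.repl.fallback
+      // is off, an import is kept only if one of its imported names is actually
+      // referenced; otherwise all imports are mimicked as before.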
def keepHandler(handler: MemberHandler) = handler match { - case _: ImportHandler => true - case x => x.definesImplicit || (x.definedNames exists wanted) + /* This case clause tries to "precisely" import only what is required. And in this + * it may miss out on some implicits, because implicits are not known in `wanted`. Thus + * it is suitable for defining classes. AFAIK while defining classes implicits are not + * needed.*/ + case h: ImportHandler if definedClass && !fallback => + h.importedNames.exists(x => wanted.contains(x)) + case _: ImportHandler => true + case x => x.definesImplicit || (x.definedNames exists wanted) } reqs match { @@ -182,7 +189,7 @@ trait SparkImports { // ambiguity errors will not be generated. Also, quote // the name of the variable, so that we don't need to // handle quoting keywords separately. - case x: ClassHandler => + case x: ClassHandler if !fallback => // I am trying to guess if the import is a defined class // This is an ugly hack, I am not 100% sure of the consequences. // Here we, let everything but "defined classes" use the import with val. diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index e2d8d5ff38dbe..c8763eb277052 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -256,6 +256,33 @@ class ReplSuite extends FunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + // We need to use local-cluster to test this case. + val output = runInterpreter("local-cluster[1,1,512]", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |import sqlContext.createSchemaRDD + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-2632 importing a method from non serializable class and not using it.") { + val output = runInterpreter("local", + """ + |class TestClass() { def testMethod = 3 } + |val t = new TestClass + |import t.testMethod + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 147b506dd5ca3..5c87da5815b64 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -36,4 +36,4 @@ export SPARK_HOME=${SPARK_PREFIX} export SPARK_CONF_DIR="$SPARK_HOME/conf" # Add the PySpark classes to the PYTHONPATH: export PYTHONPATH=$SPARK_HOME/python:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH +export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH diff --git a/sbin/spark-executor b/sbin/spark-executor index 336549f29c9ce..3621321a9bc8d 100755 --- a/sbin/spark-executor +++ b/sbin/spark-executor @@ -20,7 +20,7 @@ FWDIR="$(cd `dirname $0`/..; pwd)" export PYTHONPATH=$FWDIR/python:$PYTHONPATH -export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH +export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH echo "Running spark-executor with framework dir = $FWDIR" exec $FWDIR/bin/spark-class 
org.apache.spark.executor.MesosExecutorBackend diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh new file mode 100755 index 0000000000000..8398e6f19b511 --- /dev/null +++ b/sbin/start-thriftserver.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL Thrift server + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-thriftserver [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6decde3fcd62d..54fa96baa1e18 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -32,14 +32,23 @@ Spark Project Catalyst http://spark.apache.org/ - catalyst + catalyst + + org.scala-lang + scala-compiler + org.scala-lang scala-reflect + + org.scalamacros + quasiquotes_${scala.binary.version} + ${scala.macros.version} + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5a55be1e51558..0d26b52a84695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -85,6 +85,26 @@ object ScalaReflection { case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } + def typeOfObject: PartialFunction[Any, DataType] = { + // The data type can be determined without ambiguity. + case obj: BooleanType.JvmType => BooleanType + case obj: BinaryType.JvmType => BinaryType + case obj: StringType.JvmType => StringType + case obj: ByteType.JvmType => ByteType + case obj: ShortType.JvmType => ShortType + case obj: IntegerType.JvmType => IntegerType + case obj: LongType.JvmType => LongType + case obj: FloatType.JvmType => FloatType + case obj: DoubleType.JvmType => DoubleType + case obj: DecimalType.JvmType => DecimalType + case obj: TimestampType.JvmType => TimestampType + case null => NullType + // For other cases, there is no obvious mapping from the type of the given object to a + // Catalyst data type. A user should provide his/her specific rules + // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of + // objects and then compose the user-defined PartialFunction with this one. 
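+    // A hypothetical user-defined rule composed with this one (illustration only,
+    // not part of this patch):
+    //   val inferType: PartialFunction[Any, DataType] =
+    //     ({ case _: java.util.Date => TimestampType }: PartialFunction[Any, DataType])
+    //       .orElse(ScalaReflection.typeOfObject)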
+ } + implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a34b236c8ac6a..2c73a80f64ebf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -210,21 +210,21 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } | "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } - protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { + protected lazy val joinedRelation: Parser[LogicalPlan] = + relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { case r1 ~ jt ~ _ ~ r2 ~ cond => Join(r1, r2, joinType = jt.getOrElse(Inner), cond) - } + } - protected lazy val joinConditions: Parser[Expression] = - ON ~> expression + protected lazy val joinConditions: Parser[Expression] = + ON ~> expression - protected lazy val joinType: Parser[JoinType] = - INNER ^^^ Inner | - LEFT ~ SEMI ^^^ LeftSemi | - LEFT ~ opt(OUTER) ^^^ LeftOuter | - RIGHT ~ opt(OUTER) ^^^ RightOuter | - FULL ~ opt(OUTER) ^^^ FullOuter + protected lazy val joinType: Parser[JoinType] = + INNER ^^^ Inner | + LEFT ~ SEMI ^^^ LeftSemi | + LEFT ~ opt(OUTER) ^^^ LeftOuter | + RIGHT ~ opt(OUTER) ^^^ RightOuter | + FULL ~ opt(OUTER) ^^^ FullOuter protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02bdb64f308a5..74c0104e5b17f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -159,7 +159,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if !filter.resolved && aggregate.resolved && containsAggregate(havingCondition) => { + if aggregate.resolved && containsAggregate(havingCondition) => { val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 7abeb032964e1..a0e25775da6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.{errors, trees} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.BaseRelation +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode /** @@ -36,7 +36,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: 
TreeType, function: Str case class UnresolvedRelation( databaseName: Option[String], tableName: String, - alias: Option[String] = None) extends BaseRelation { + alias: Option[String] = None) extends LeafNode { override def output = Nil override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5c8c810d9135a..f44521d6381c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -202,7 +202,7 @@ package object dsl { // Protobuf terminology def required = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a) + def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 9ce1f01056462..f38f99569f207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,72 +17,37 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.trees /** * A bound reference points to a specific slot in the input tuple, allowing the actual value * to be retrieved more efficiently. However, since operations like column pruning can change * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ -case class BoundReference(ordinal: Int, baseReference: Attribute) - extends Attribute with trees.LeafNode[Expression] { +case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends Expression with trees.LeafNode[Expression] { type EvaluatedType = Any - override def nullable = baseReference.nullable - override def dataType = baseReference.dataType - override def exprId = baseReference.exprId - override def qualifiers = baseReference.qualifiers - override def name = baseReference.name - - override def newInstance = BoundReference(ordinal, baseReference.newInstance) - override def withNullability(newNullability: Boolean) = - BoundReference(ordinal, baseReference.withNullability(newNullability)) - override def withQualifiers(newQualifiers: Seq[String]) = - BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) + override def references = Set.empty - override def toString = s"$baseReference:$ordinal" + override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } -/** - * Used to denote operators that do their own binding of attributes internally. 
- */ -trait NoBind { self: trees.TreeNode[_] => } - -class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { - import BindReferences._ - - def apply(plan: TreeNode): TreeNode = { - plan.transform { - case n: NoBind => n.asInstanceOf[TreeNode] - case leafNode if leafNode.children.isEmpty => leafNode - case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => - bindReference(e, unaryNode.children.head.output) - } - } - } -} - object BindReferences extends Logging { def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - // TODO: This fallback is required because some operators (such as ScriptTransform) - // produce new attributes that can't be bound. Likely the right thing to do is remove - // this rule and require all operators to explicitly bind to the input schema that - // they specify. - logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") - a + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } else { - BoundReference(ordinal, a) + BoundReference(ordinal, a.dataType, a.nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 2c71d2c7b3563..8fc5896974438 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions + /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. + * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ -class Projection(expressions: Seq[Expression]) extends (Row => Row) { +class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) @@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { } /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of th - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. - * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` - * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` - * and hold on to the returned [[Row]] before calling `next()`. + * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified + * expressions. 
+ * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ -case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { +case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) private[this] val exprArray = expressions.toArray - private[this] val mutableRow = new GenericMutableRow(exprArray.size) + private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) def currentValue: Row = mutableRow - def apply(input: Row): Row = { + override def target(row: MutableRow): MutableProjection = { + mutableRow = row + this + } + + override def apply(input: Row): Row = { var i = 0 while (i < exprArray.length) { mutableRow(i) = exprArray(i).eval(input) @@ -76,6 +77,12 @@ class JoinedRow extends Row { private[this] var row1: Row = _ private[this] var row2: Row = _ + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ def apply(r1: Row, r2: Row): Row = { row1 = r1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 74ae723686cfe..c9a63e201ef60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -32,6 +32,16 @@ object Row { * }}} */ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + + /** + * This method can be used to construct a [[Row]] with the given values. + */ + def apply(values: Any*): Row = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) } /** @@ -88,15 +98,6 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - - /** - * Experimental - * - * Returns a mutable string builder for the specified column. A given row should return the - * result of any mutations made to the returned buffer next time getString is called for the same - * column. - */ - def getStringBuilder(ordinal: Int): StringBuilder } /** @@ -180,6 +181,35 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + // Custom hashCode function that matches the efficient code generated version. + override def hashCode(): Int = { + var result: Int = 37 + + var i = 0 + while (i < values.length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + def copy() = this } @@ -187,8 +217,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - def getStringBuilder(ordinal: Int): StringBuilder = ??? 
- override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 5e089f7618e0a..acddf5e9c7004 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def eval(input: Row): Any = { children.size match { + case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => function.asInstanceOf[(Any, Any) => Any]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index e787c59e75723..eb8898900d6a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -21,8 +21,16 @@ import scala.language.dynamics import org.apache.spark.sql.catalyst.types.DataType -case object DynamicType extends DataType +/** + * The data type representing [[DynamicRow]] values. + */ +case object DynamicType extends DataType { + def simpleString: String = "dynamic" +} +/** + * Wrap a [[Row]] as a [[DynamicRow]]. + */ case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow @@ -37,6 +45,11 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { } } +/** + * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language. + * Since the type of the column is not known at compile time, all attributes are converted to + * strings before being passed to the function. + */ class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) extends GenericRow(values) with Dynamic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala new file mode 100644 index 0000000000000..5b398695bf560 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import com.google.common.cache.{CacheLoader, CacheBuilder} + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + import scala.tools.reflect.ToolBox + + protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() + + protected val rowType = typeOf[Row] + protected val mutableRowType = typeOf[MutableRow] + protected val genericRowType = typeOf[GenericRow] + protected val genericMutableRowType = typeOf[GenericMutableRow] + + protected val projectionType = typeOf[Projection] + protected val mutableProjectionType = typeOf[MutableProjection] + + private val curId = new java.util.concurrent.atomic.AtomicInteger() + private val javaSeparator = "$" + + /** + * Generates a class for a given input expression. Called when there is not cached code + * already available. + */ + protected def create(in: InType): OutType + + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ + protected def canonicalize(in: InType): InType + + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + + /** + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint + */ + protected val cache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build( + new CacheLoader[InType, OutType]() { + override def load(in: InType): OutType = globalLock.synchronized { + create(in) + } + }) + + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = + apply(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) + + /** + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) + */ + protected def freshName(prefix: String): TermName = { + newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + } + + /** + * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not + * valid if `nullTerm` is set to `false`. 
+ * @param objectTerm A possibly boxed version of the result of evaluating this expression. + */ + protected case class EvaluatedExpression( + code: Seq[Tree], + nullTerm: TermName, + primitiveTerm: TermName, + objectTerm: TermName) + + /** + * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that + * can be used to determine the result of evaluating the expression on an input row. + */ + def expressionEvaluator(e: Expression): EvaluatedExpression = { + val primitiveTerm = freshName("primitiveTerm") + val nullTerm = freshName("nullTerm") + val objectTerm = freshName("objectTerm") + + implicit class Evaluate1(e: Expression) { + def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${f(eval.primitiveTerm)} + """.children + } + } + + implicit class Evaluate2(expressions: (Expression, Expression)) { + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function from two primitive term names to a tree that evaluates them. + */ + def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = + evaluateAs(expressions._1.dataType)(f) + + def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (expressions._1.dataType != expressions._2.dataType) { + log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") + } + + val eval1 = expressionEvaluator(expressions._1) + val eval2 = expressionEvaluator(expressions._2) + val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + + eval1.code ++ eval2.code ++ + q""" + val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} + val $primitiveTerm: ${termForType(resultType)} = + if($nullTerm) { + ${defaultPrimitive(resultType)} + } else { + $resultCode.asInstanceOf[${termForType(resultType)}] + } + """.children : Seq[Tree] + } + } + + val inputTuple = newTermName(s"i") + + // TODO: Skip generation of null handling code when expression are not nullable. 
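+    // Each case below emits the quasiquoted statements that define $nullTerm and
+    // $primitiveTerm for the matched expression; expressions with no case here fall
+    // back to the interpreted eval() call at the end of this method.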
+ val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { + case b @ BoundReference(ordinal, dataType, nullable) => + val nullValue = q"$inputTuple.isNullAt($ordinal)" + q""" + val $nullTerm: Boolean = $nullValue + val $primitiveTerm: ${termForType(dataType)} = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${getColumn(inputTuple, dataType, ordinal)} + """.children + + case expressions.Literal(null, dataType) => + q""" + val $nullTerm = true + val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] + """.children + + case expressions.Literal(value: Boolean, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: String, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Int, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Long, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case Cast(e @ BinaryType(), StringType) => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + """.children + + case Cast(child @ NumericType(), IntegerType) => + child.castOrNull(c => q"$c.toInt", IntegerType) + + case Cast(child @ NumericType(), LongType) => + child.castOrNull(c => q"$c.toLong", LongType) + + case Cast(child @ NumericType(), DoubleType) => + child.castOrNull(c => q"$c.toDouble", DoubleType) + + case Cast(child @ NumericType(), FloatType) => + child.castOrNull(c => q"$c.toFloat", IntegerType) + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case Cast(e, StringType) if e.dataType != TimestampType => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + ${eval.primitiveTerm}.toString + """.children + + case EqualTo(e1, e2) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } + + /* TODO: Fix null semantics. 
+ case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => + val eval = expressionEvaluator(e1) + + val checks = list.map { + case expressions.Literal(v: String, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + case expressions.Literal(v: Int, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + } + + val funcName = newTermName(s"isIn${curId.getAndIncrement()}") + + q""" + def $funcName: Boolean = { + ..${eval.code} + if(${eval.nullTerm}) return false + ..$checks + return false + } + val $nullTerm = false + val $primitiveTerm = $funcName + """.children + */ + + case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + case LessThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + + case And(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && !${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = false + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = true + } + """.children + + case Or(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && ${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = true + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = false + } + """.children + + case Not(child) => + // Uh, bad function name... 
+ child.castOrNull(c => q"!$c", BooleanType) + + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } + case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" } + + case IsNotNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} + """.children + + case IsNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} + """.children + + case c @ Coalesce(children) => + q""" + var $nullTerm = true + var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} + """.children ++ + children.map { c => + val eval = expressionEvaluator(c) + q""" + if($nullTerm) { + ..${eval.code} + if(!${eval.nullTerm}) { + $nullTerm = false + $primitiveTerm = ${eval.primitiveTerm} + } + } + """ + } + + case i @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition) + val trueEval = expressionEvaluator(trueValue) + val falseEval = expressionEvaluator(falseValue) + + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} + ..${condEval.code} + if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + ..${trueEval.code} + $nullTerm = ${trueEval.nullTerm} + $primitiveTerm = ${trueEval.primitiveTerm} + } else { + ..${falseEval.code} + $nullTerm = ${falseEval.nullTerm} + $primitiveTerm = ${falseEval.primitiveTerm} + } + """.children + } + + // If there was no match in the partial function above, we fall back on calling the interpreted + // expression evaluator. 
+ val code: Seq[Tree] = + primitiveEvaluation.lift.apply(e).getOrElse { + log.debug(s"No rules to generate $e") + val tree = reify { e } + q""" + val $objectTerm = $tree.eval(i) + val $nullTerm = $objectTerm == null + val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] + """.children + } + + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + } + + protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + dataType match { + case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + } + } + + protected def setColumn( + destinationRow: TermName, + dataType: DataType, + ordinal: Int, + value: TermName) = { + dataType match { + case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => q"$destinationRow.update($ordinal, $value)" + } + } + + protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") + protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + + protected def primitiveForType(dt: DataType) = dt match { + case IntegerType => "Int" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case StringType => "String" + } + + protected def defaultPrimitive(dt: DataType) = dt match { + case BooleanType => ru.Literal(Constant(false)) + case FloatType => ru.Literal(Constant(-1.0.toFloat)) + case StringType => ru.Literal(Constant("")) + case ShortType => ru.Literal(Constant(-1.toShort)) + case LongType => ru.Literal(Constant(1L)) + case ByteType => ru.Literal(Constant(-1.toByte)) + case DoubleType => ru.Literal(Constant(-1.toDouble)) + case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. + case IntegerType => ru.Literal(Constant(-1)) + case _ => ru.Literal(Constant(null)) + } + + protected def termForType(dt: DataType) = dt match { + case n: NativeType => n.tag + case _ => typeTag[Any] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala new file mode 100644 index 0000000000000..a419fd7ecb39b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * input [[Row]] for a fixed set of [[Expression Expressions]]. + */ +object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + val mutableRowName = newTermName("mutableRow") + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => + val evaluationCode = expressionEvaluator(e) + + evaluationCode.code :+ + q""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i) + else + ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} + """ + } + + val code = + q""" + () => { new $mutableProjectionType { + + private[this] var $mutableRowName: $mutableRowType = + new $genericMutableRowType(${expressions.size}) + + def target(row: $mutableRowType): $mutableProjectionType = { + $mutableRowName = row + this + } + + /* Provide immutable access to the last projected row. */ + def currentValue: $rowType = mutableRow + + def apply(i: $rowType): $rowType = { + ..$projectionCode + mutableRow + } + } } + """ + + log.debug(s"code for ${expressions.mkString(",")}:\n$code") + toolBox.eval(code).asInstanceOf[() => MutableProjection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala new file mode 100644 index 0000000000000..4211998f7511a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import com.typesafe.scalalogging.slf4j.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NumericType} + +/** + * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of + * [[Expression Expressions]]. 
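Every generator in this package follows the shape GenerateMutableProjection just showed: assemble a Tree with quasiquotes, hand it to toolBox.eval, and cast the result to the target function type. A minimal, self-contained illustration of that mechanism in ordinary Scala (assuming scala-compiler is on the classpath and a 2.11+ compiler where quasiquotes are built in; the 2.10 line this patch targets needs the quasiquotes compiler plugin):

import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

object RuntimeCompileDemo extends App {
  val toolBox = currentMirror.mkToolBox()

  // Assemble an AST with a quasiquote, then compile and evaluate it at runtime.
  val tree: Tree = q"(x: Int) => x * x"
  val square = toolBox.eval(tree).asInstanceOf[Int => Int]

  println(square(12)) // 144
}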
+ */ +object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) + + protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { + val a = newTermName("a") + val b = newTermName("b") + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = expressionEvaluator(order.child) + val evalB = expressionEvaluator(order.child) + + val compare = order.child.dataType match { + case _: NumericType => + q""" + val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} + if(comp != 0) { + return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} + } + """ + case StringType => + if (order.direction == Ascending) { + q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + } else { + q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + } + } + + q""" + i = $a + ..${evalA.code} + i = $b + ..${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) q"-1" else q"1"} + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) q"1" else q"-1"} + } else { + $compare + } + """ + } + + val q"class $orderingName extends $orderingType { ..$body }" = reify { + class SpecificOrdering extends Ordering[Row] { + val o = ordering + } + }.tree.children.head + + val code = q""" + class $orderingName extends $orderingType { + ..$body + def compare(a: $rowType, b: $rowType): Int = { + var i: $rowType = null // Holds current row being evaluated. + ..$comparisons + return 0 + } + } + new $orderingName() + """ + logger.debug(s"Generated Ordering: $code") + toolBox.eval(code).asInstanceOf[Ordering[Row]] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala new file mode 100644 index 0000000000000..2a0935c790cf3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. 
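The generated compare loop above places a null sort key before any non-null key when ascending (after it when descending) and only falls through to the next key on a tie. A hand-written equivalent for a single ascending Option[Int] key, for comparison:

// Hand-written analogue of the generated null handling for one ascending sort key.
val ascendingNullsFirst: Ordering[Option[Int]] = new Ordering[Option[Int]] {
  def compare(a: Option[Int], b: Option[Int]): Int = (a, b) match {
    case (None, None)       => 0            // both null: tie
    case (None, Some(_))    => -1           // null sorts first when ascending
    case (Some(_), None)    => 1
    case (Some(x), Some(y)) => x.compare(y) // both present: compare the values
  }
}

assert(List(Some(2), None, Some(1)).sorted(ascendingNullsFirst) == List(None, Some(1), Some(2)))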
+ */ +object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) + + protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = + BindReferences.bindReference(in, inputSchema) + + protected def create(predicate: Expression): ((Row) => Boolean) = { + val cEval = expressionEvaluator(predicate) + + val code = + q""" + (i: $rowType) => { + ..${cEval.code} + if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + } + """ + + log.debug(s"Generated predicate '$predicate':\n$code") + toolBox.eval(code).asInstanceOf[Row => Boolean] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala new file mode 100644 index 0000000000000..77fa02c13de30 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + + +/** + * Generates bytecode that produces a new [[Row]] object based on a fixed set of input + * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom + * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. + */ +object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + // Make Mutablility optional... + protected def create(expressions: Seq[Expression]): Projection = { + val tupleLength = ru.Literal(Constant(expressions.length)) + val lengthDef = q"final val length = $tupleLength" + + /* TODO: Configurable... 
+ val nullFunctions = + q""" + private final val nullSet = new org.apache.spark.util.collection.BitSet(length) + final def setNullAt(i: Int) = nullSet.set(i) + final def isNullAt(i: Int) = nullSet.get(i) + """ + */ + + val nullFunctions = + q""" + private[this] var nullBits = new Array[Boolean](${expressions.size}) + final def setNullAt(i: Int) = { nullBits(i) = true } + final def isNullAt(i: Int) = nullBits(i) + """.children + + val tupleElements = expressions.zipWithIndex.flatMap { + case (e, i) => + val elementName = newTermName(s"c$i") + val evaluatedExpression = expressionEvaluator(e) + val iLit = ru.Literal(Constant(i)) + + q""" + var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + { + ..${evaluatedExpression.code} + if(${evaluatedExpression.nullTerm}) + setNullAt($iLit) + else + $elementName = ${evaluatedExpression.primitiveTerm} + } + """.children : Seq[Tree] + } + + val iteratorFunction = { + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + q"final def iterator = Iterator[Any](..$allColumns)" + } + + val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" + val applyFunction = { + val cases = (0 until expressions.size).map { i => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" + } + q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }" + } + + val updateFunction = { + val cases = expressions.zipWithIndex.map {case (e, i) => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q""" + if(i == $ordinal) { + if(value == null) { + setNullAt(i) + } else { + $elementName = value.asInstanceOf[${termForType(e.dataType)}] + return + } + }""" + } + q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" + } + + val specificAccessorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) return $elementName" :: Nil + case _ => Nil + } + + q""" + final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } + + val specificMutatorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? 
+ q"if(i == $i) { $elementName = value; return }" :: Nil + case _ => Nil + } + + q""" + final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + ..$ifStatements; + $accessorFailure + }""" + } + + val hashValues = expressions.zipWithIndex.map { case (e,i) => + val elementName = newTermName(s"c$i") + val nonNull = e.dataType match { + case BooleanType => q"if ($elementName) 0 else 1" + case ByteType | ShortType | IntegerType => q"$elementName.toInt" + case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" + case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case DoubleType => + q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" + case _ => q"$elementName.hashCode" + } + q"if (isNullAt($i)) 0 else $nonNull" + } + + val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + + val hashCodeFunction = + q""" + override def hashCode(): Int = { + var result: Int = 37 + ..$hashUpdates + result + } + """ + + val columnChecks = (0 until expressions.size).map { i => + val elementName = newTermName(s"c$i") + q"if (this.$elementName != specificType.$elementName) return false" + } + + val equalsFunction = + q""" + override def equals(other: Any): Boolean = other match { + case specificType: SpecificRow => + ..$columnChecks + return true + case other => super.equals(other) + } + """ + + val copyFunction = + q""" + final def copy() = new $genericRowType(this.toArray) + """ + + val classBody = + nullFunctions ++ ( + lengthDef +: + iteratorFunction +: + applyFunction +: + updateFunction +: + equalsFunction +: + hashCodeFunction +: + copyFunction +: + (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) + + val code = q""" + final class SpecificRow(i: $rowType) extends $mutableRowType { + ..$classBody + } + + new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + """ + + log.debug( + s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") + toolBox.eval(code).asInstanceOf[Projection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala new file mode 100644 index 0000000000000..80c7dfd376c96 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.rules +import org.apache.spark.sql.catalyst.util + +/** + * A collection of generators that build custom bytecode at runtime for performing the evaluation + * of catalyst expression. + */ +package object codegen { + + /** + * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala + * 2.10. + */ + protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock + + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ + object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { + val batches = + Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil + + object CleanExpressions extends rules.Rule[Expression] { + def apply(e: Expression): Expression = e transform { + case Alias(c, _) => c + } + } + } + + /** + * :: DeveloperApi :: + * Dumps the bytecode from a class to the screen using javap. + */ + @DeveloperApi + object DumpByteCode { + import scala.sys.process._ + val dumpDirectory = util.getTempFilePath("sparkSqlByteCode") + dumpDirectory.mkdir() + + def apply(obj: Any): Unit = { + val generatedClass = obj.getClass + val classLoader = + generatedClass + .getClassLoader + .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] + val generatedBytes = classLoader.classBytes(generatedClass.getName) + + val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName) + if (!packageDir.exists()) { packageDir.mkdir() } + + val classFile = + new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class") + + val outfile = new java.io.FileOutputStream(classFile) + outfile.write(generatedBytes) + outfile.close() + + println( + s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!) 
+ } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 5d3bb25ad568c..c1154eb81c319 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.types._ /** @@ -31,8 +33,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { override def foldable = child.foldable && ordinal.foldable override def references = children.flatMap(_.references).toSet def dataType = child.dataType match { - case ArrayType(dt) => dt - case MapType(_, vt) => vt + case ArrayType(dt, _) => dt + case MapType(_, vt, _) => vt } override lazy val resolved = childrenResolved && @@ -61,7 +63,6 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } } else { val baseValue = value.asInstanceOf[Map[Any, _]] - val key = ordinal.eval(input) baseValue.get(key).orNull } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index dd78614754e12..3d41acb79e5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.types._ @@ -84,8 +86,8 @@ case class Explode(attributeNames: Seq[String], child: Expression) (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et) => et :: Nil - case MapType(kt,vt) => kt :: vt :: Nil + case ArrayType(et, _) => et :: Nil + case MapType(kt,vt, _) => kt :: vt :: Nil } // TODO: Move this pattern into Generator. @@ -102,10 +104,10 @@ case class Explode(attributeNames: Seq[String], child: Expression) override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { - case ArrayType(_) => + case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) - case MapType(_, _) => + case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 934bad8c27294..ed69928ae9eb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -28,8 +28,8 @@ object NamedExpression { } /** - * A globally (within this JVM) id for a given named expression. - * Used to identify with attribute output by a relation is being + * A globally unique (within this JVM) id for a given named expression. 
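ExpressionCanonicalizer above exists so that the generator cache keys on structure rather than on names: two trees that differ only by an output alias normalise to the same expression and therefore reuse one compiled class. A toy illustration of the idea (the Expr, Col, Add and Alias types here are illustrative stand-ins, not catalyst's):

// Stripping aliases makes structurally equal expressions compare equal, so a
// cache keyed on the canonical form hands back the same compiled code for both.
sealed trait Expr
case class Col(ordinal: Int) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Alias(child: Expr, name: String) extends Expr

def canonicalize(e: Expr): Expr = e match {
  case Alias(child, _) => canonicalize(child)            // same rule as CleanExpressions above
  case Add(l, r)       => Add(canonicalize(l), canonicalize(r))
  case other           => other
}

val a = Alias(Add(Col(0), Col(1)), "sum1")
val b = Alias(Add(Col(0), Col(1)), "total")
assert(canonicalize(a) == canonicalize(b))  // one cache entry serves both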
+ * Used to identify which attribute output by a relation is being * referenced in a subsequent computation. */ case class ExprId(id: Long) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index b6f2451b52e1f..55d95991c5f11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -47,4 +47,30 @@ package org.apache.spark.sql.catalyst * ==Evaluation== * The result of expressions can be evaluated using the `Expression.apply(Row)` method. */ -package object expressions +package object expressions { + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + */ + abstract class Projection extends (Row => Row) + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` + * and hold on to the returned [[Row]] before calling `next()`. + */ + abstract class MutableProjection extends Projection { + def currentValue: Row + + /** Uses the given row to store the output of the projection. */ + def target(row: MutableRow): MutableProjection + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 06b94a98d3cd0..5976b0ddf3e03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType object InterpretedPredicate { + def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + apply(BindReferences.bindReference(expression, inputSchema)) + def apply(expression: Expression): (Row => Boolean) = { (r: Row) => expression.eval(r).asInstanceOf[Boolean] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala new file mode 100644 index 0000000000000..ca9642954eb27 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object catalyst { + /** + * A JVM-global lock that should be used to prevent thread safety issues when using things in + * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for + * 2.10.* builds. See SI-6240 for more details. + */ + protected[catalyst] object ScalaReflectionLock + + protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 67833664b35ae..781ba489b44c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 026692abe067d..bc763a4e06e67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import org.apache.spark.sql.Logging - import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -104,6 +103,77 @@ object PhysicalOperation extends PredicateHelper { } } +/** + * Matches a logical aggregation that can be performed on distributed data in two steps. The first + * operates on the data in each partition performing partial aggregation for each group. The second + * occurs after the shuffle and completes the aggregation. + * + * This pattern will only match if all aggregate expressions can be computed partially and will + * return the rewritten aggregation expressions for both phases. + * + * The returned values for this match are as follows: + * - Grouping attributes for the final aggregation. + * - Aggregates for the final aggregation. + * - Grouping expressions for the partial aggregation. + * - Partial aggregate expressions. + * - Input to the aggregation. + */ +object PartialAggregation { + type ReturnType = + (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions. 
+ val allAggregates = + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + // Collect all aggregate expressions that can be computed partially. + val partialAggregates = + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + + // Only do partial aggregation if supported by all aggregate expressions. + if (allAggregates.size == partialAggregates.size) { + // Create a map of expressions to their partial evaluations for all aggregate expressions. + val partialEvaluations: Map[Long, SplitEvaluation] = + partialAggregates.map(a => (a.id, a.asPartial)).toMap + + // We need to pass all grouping expressions though so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if partialEvaluations.contains(e.id) => + partialEvaluations(e.id).finalEvaluation + case e: Expression if namedGroupingExpressions.contains(e) => + namedGroupingExpressions(e).toAttribute + }).asInstanceOf[Seq[NamedExpression]] + + val partialComputation = + (namedGroupingExpressions.values ++ + partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + + val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + + Some( + (namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child)) + } else { + None + } + case _ => None + } +} + + /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. 
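The PartialAggregation extractor above splits one logical Aggregate into a map-side partial step and a post-shuffle final step. The classic case is AVG: the partial phase keeps a (sum, count) state per group within each partition, and the final phase merges those states and divides. A self-contained sketch of that two-phase split over plain Scala collections (the data and names are illustrative):

// Two "partitions" of (key, value) records.
val partitions: Seq[Seq[(String, Double)]] =
  Seq(Seq(("a", 1.0), ("b", 2.0)), Seq(("a", 3.0)))

// Phase 1 (before the shuffle): partial aggregation inside each partition.
val partial: Seq[Map[String, (Double, Long)]] = partitions.map { part =>
  part.groupBy(_._1).map { case (k, vs) => (k, (vs.map(_._2).sum, vs.size.toLong)) }
}

// Phase 2 (after the shuffle): merge the partial states and finish the average.
val averages: Map[String, Double] = partial.flatten
  .groupBy(_._1)
  .map { case (k, states) =>
    val (sum, count) = states.map(_._2).reduce((x, y) => (x._1 + y._1, x._2 + y._2))
    (k, sum / count)
  }

// averages == Map("a" -> 2.0, "b" -> 2.0)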
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7b82e19b2e714..0988b0c6d990c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -125,51 +125,10 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - protected def generateSchemaString(schema: Seq[Attribute]): String = { - val builder = new StringBuilder - builder.append("root\n") - val prefix = " |" - schema.foreach { attribute => - val name = attribute.name - val dataType = attribute.dataType - dataType match { - case fields: StructType => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(fields: StructType) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(elementType: DataType) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case _ => builder.append(s"$prefix-- $name: $dataType\n") - } - } - - builder.toString() - } - - protected def generateSchemaString( - schema: StructType, - prefix: String, - builder: StringBuilder): StringBuilder = { - schema.fields.foreach { - case StructField(name, fields: StructType, _) => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(fields: StructType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(elementType: DataType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case StructField(name, fieldType: DataType, _) => - builder.append(s"$prefix-- $name: $fieldType\n") - } - - builder - } + def schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ - def schemaString: String = generateSchemaString(output) + def schemaString: String = schema.treeString /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index edc37e3877c0e..888cb08e95f06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] { self: Product => + /** + * Estimates of various statistics. The default estimation logic simply lazily multiplies the + * corresponding statistic produced by the children. To override this behavior, override + * `statistics` and assign it an overriden version of `Statistics`. + * + * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the + * performance of the implementations. The reason is that estimations might get triggered in + * performance-critical processes, such as query plan planning. + * + * @param sizeInBytes Physical size in bytes. 
For leaf operators this defaults to 1, otherwise it + * defaults to the product of children's `sizeInBytes`. + */ + case class Statistics( + sizeInBytes: BigInt + ) + lazy val statistics: Statistics = Statistics( + sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product + ) + /** * Returns the set of attributes that are referenced by this node * during evaluation. @@ -92,6 +111,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => + override lazy val statistics: Statistics = + throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") + // Leaf nodes by definition cannot reference any input attributes. override def references = Set.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 1537de259c5b4..3cb407217c4c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -177,7 +177,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case StructType(fields) => StructType(fields.map(f => StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable))) - case ArrayType(elemType) => ArrayType(lowerCaseSchema(elemType)) + case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull) case otherType => otherType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 1d5f033f0d274..481a5a4f212b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -35,7 +35,7 @@ abstract class Command extends LeafNode { */ case class NativeCommand(cmd: String) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) + Seq(AttributeReference("result", StringType, nullable = false)()) } /** @@ -43,8 +43,7 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(0, AttributeReference("key", StringType, nullable = false)()), - BoundReference(1, AttributeReference("value", StringType, nullable = false)())) + AttributeReference("", StringType, nullable = false)()) } /** @@ -53,7 +52,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman */ case class ExplainCommand(plan: LogicalPlan) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) + Seq(AttributeReference("plan", StringType, nullable = false)()) } /** @@ -72,7 +71,7 @@ case class DescribeCommand( isExtended: Boolean) extends Command { override def output = Seq( // Column names are based on Hive. 
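The Statistics estimate added to LogicalPlan above propagates sizeInBytes bottom-up: by default a node's estimate is the product of its children's, the safe worst case for a join, while the LeafNode override forces concrete leaf operators to supply a real size. A toy model of that default, to make the propagation concrete (the Plan and Stats types here are illustrative, not catalyst's):

// Minimal analogue of the default estimation: a node's size is the product of
// its children's sizes, which is the upper bound for a join's output.
case class Stats(sizeInBytes: BigInt)

sealed trait Plan { def children: Seq[Plan]; def stats: Stats }

case class Leaf(size: BigInt) extends Plan {
  val children: Seq[Plan] = Nil
  val stats: Stats = Stats(size)          // leaves must supply their own estimate
}

case class Join(left: Plan, right: Plan) extends Plan {
  val children: Seq[Plan] = Seq(left, right)
  lazy val stats: Stats = Stats(children.map(_.stats.sizeInBytes).product)
}

// Join(Leaf(1000000), Leaf(200)).stats.sizeInBytes == BigInt(200000000):
// one side is small enough that a planner could, for example, choose a broadcast join.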
- BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), - BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), - BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) + AttributeReference("col_name", StringType, nullable = false)(), + AttributeReference("data_type", StringType, nullable = false)(), + AttributeReference("comment", StringType, nullable = false)()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index 1076537bc7602..f8960b3fe7a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e32adb76fe146..6aa407c836aec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql -package catalyst -package rules +package org.apache.spark.sql.catalyst.rules +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -72,7 +71,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } iteration += 1 if (iteration > batch.strategy.maxIterations) { - logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}") + // Only log if this is a rule that is supposed to run more than once. + if (iteration != 2) { + logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + } continue = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index d159ecdd5d781..9a28d035a10a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.Logger - /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -35,5 +33,6 @@ import org.apache.spark.sql.Logger */ package object trees { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. 
- protected val logger = Logger("catalyst.trees") + protected val logger = + com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cd4b5e9c1b529..b52ee6d3378a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -23,16 +23,13 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers +import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils /** - * A JVM-global lock that should be used to prevent thread safety issues when using things in - * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for - * 2.10.* builds. See SI-6240 for more details. + * Utility functions for working with DataTypes. */ -protected[catalyst] object ScalaReflectionLock - object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = "StringType" ^^^ StringType | @@ -48,11 +45,13 @@ object DataType extends RegexParsers { "TimestampType" ^^^ TimestampType protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) } protected lazy val structField: Parser[StructField] = @@ -85,6 +84,21 @@ object DataType extends RegexParsers { case Success(result, _) => result case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } } abstract class DataType { @@ -95,49 +109,65 @@ abstract class DataType { } def isPrimitive: Boolean = false + + def simpleString: String +} + +case object NullType extends DataType { + def simpleString: String = "null" } -case object NullType extends DataType +object NativeType { + def all = Seq( + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + def unapply(dt: DataType): Boolean = all.contains(dt) +} trait PrimitiveType extends DataType { override def isPrimitive = true } abstract class NativeType extends DataType { - type JvmType - @transient val tag: TypeTag[JvmType] - val ordering: Ordering[JvmType] + private[sql] type JvmType + @transient private[sql] val tag: TypeTag[JvmType] + private[sql] val ordering: Ordering[JvmType] - @transient val classTag = ScalaReflectionLock.synchronized { + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { val 
mirror = runtimeMirror(Utils.getSparkClassLoader) ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) } } case object StringType extends NativeType with PrimitiveType { - type JvmType = String - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "string" } case object BinaryType extends DataType with PrimitiveType { - type JvmType = Array[Byte] + private[sql] type JvmType = Array[Byte] + def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { - type JvmType = Boolean - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "boolean" } case object TimestampType extends NativeType { - type JvmType = Timestamp + private[sql] type JvmType = Timestamp - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = new Ordering[JvmType] { + private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } + + def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -146,7 +176,11 @@ abstract class NumericType extends NativeType with PrimitiveType { // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. 
- val numeric: Numeric[JvmType] + private[sql] val numeric: Numeric[JvmType] +} + +object NumericType { + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } /** Matcher for any expressions that evaluate to [[IntegralType]]s */ @@ -158,39 +192,43 @@ object IntegralType { } abstract class IntegralType extends NumericType { - val integral: Integral[JvmType] + private[sql] val integral: Integral[JvmType] } case object LongType extends IntegralType { - type JvmType = Long - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Long]] - val integral = implicitly[Integral[Long]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Long]] + private[sql] val integral = implicitly[Integral[Long]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "long" } case object IntegerType extends IntegralType { - type JvmType = Int - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Int]] - val integral = implicitly[Integral[Int]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Int]] + private[sql] val integral = implicitly[Integral[Int]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "integer" } case object ShortType extends IntegralType { - type JvmType = Short - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Short]] - val integral = implicitly[Integral[Short]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Short]] + private[sql] val integral = implicitly[Integral[Short]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "short" } case object ByteType extends IntegralType { - type JvmType = Byte - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Byte]] - val integral = implicitly[Integral[Byte]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Byte]] + private[sql] val integral = implicitly[Integral[Byte]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -201,47 +239,159 @@ object FractionalType { } } abstract class FractionalType extends NumericType { - val fractional: Fractional[JvmType] + private[sql] val fractional: Fractional[JvmType] } case object DecimalType extends FractionalType { - type JvmType = BigDecimal - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[BigDecimal]] - val fractional = implicitly[Fractional[BigDecimal]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = BigDecimal + @transient 
private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[BigDecimal]] + private[sql] val fractional = implicitly[Fractional[BigDecimal]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "decimal" } case object DoubleType extends FractionalType { - type JvmType = Double - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Double]] - val fractional = implicitly[Fractional[Double]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "double" } case object FloatType extends FractionalType { - type JvmType = Float - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Float]] - val fractional = implicitly[Fractional[Float]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "float" +} + +object ArrayType { + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, false) +} + +/** + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * @param elementType The data type of values. + * @param containsNull Indicates if values have `null` values + */ +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append( + s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + DataType.buildFormattedString(elementType, s"$prefix |", builder) + } + + def simpleString: String = "array" } -case class ArrayType(elementType: DataType) extends DataType +/** + * A field inside a StructType. + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. 
+ */ +case class StructField(name: String, dataType: DataType, nullable: Boolean) { -case class StructField(name: String, dataType: DataType, nullable: Boolean) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + DataType.buildFormattedString(dataType, s"$prefix |", builder) + } +} object StructType { - def fromAttributes(attributes: Seq[Attribute]): StructType = { + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - } - // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq) + private def validateFields(fields: Seq[StructField]): Boolean = + fields.map(field => field.name).distinct.size == fields.size } case class StructType(fields: Seq[StructField]) extends DataType { - def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + require(StructType.validateFields(fields), "Found fields with the same name.") + + /** + * Returns all field names in a [[Seq]]. + */ + lazy val fieldNames: Seq[String] = fields.map(_.name) + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a name matching the given name, `null` will be returned. + */ + def apply(name: String): StructField = { + nameToField.get(name).getOrElse( + throw new IllegalArgumentException(s"Field ${name} does not exist.")) + } + + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names. + * Those names which do not have matching fields will be ignored. + */ + def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (!nonExistFields.isEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } + // Preserve the original order of fields. + StructType(fields.filter(f => names.contains(f.name))) + } + + protected[sql] def toAttributes = + fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + + def treeString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printTreeString(): Unit = println(treeString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + def simpleString: String = "struct" } -case class MapType(keyType: DataType, valueType: DataType) extends DataType +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, true) +} + +/** + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. 
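The reworked schema types above (ArrayType and MapType carrying containsNull / valueContainsNull, plus field lookup and tree-formatted printing on StructType) can be exercised directly. A small hypothetical usage sketch, using only constructors and methods that appear in this patch:

import org.apache.spark.sql.catalyst.types._

// Hypothetical usage of the reworked schema types.
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("scores", ArrayType(IntegerType), nullable = false),         // containsNull defaults to false
  StructField("props", MapType(StringType, StringType), nullable = true))) // valueContainsNull defaults to true

schema.printTreeString()      // prints the "root / |-- ..." layout built by buildFormattedString
println(schema("name"))       // the field named "name"
println(schema(Set("name")))  // a StructType with just that field, original order kept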
+ */ +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") + builder.append(s"${prefix}-- value: ${valueType.simpleString} " + + s"(valueContainsNull = ${valueContainsNull})\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + def simpleString: String = "map" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index c0438dbe52a47..e75373d5a74a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst +import java.math.BigInteger import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ case class PrimitiveData( @@ -148,4 +148,68 @@ class ScalaReflectionSuite extends FunSuite { StructField("_2", StringType, nullable = true))), nullable = true)) } + + test("get data type of a value") { + // BooleanType + assert(BooleanType === typeOfObject(true)) + assert(BooleanType === typeOfObject(false)) + + // BinaryType + assert(BinaryType === typeOfObject("string".getBytes)) + + // StringType + assert(StringType === typeOfObject("string")) + + // ByteType + assert(ByteType === typeOfObject(127.toByte)) + + // ShortType + assert(ShortType === typeOfObject(32767.toShort)) + + // IntegerType + assert(IntegerType === typeOfObject(2147483647)) + + // LongType + assert(LongType === typeOfObject(9223372036854775807L)) + + // FloatType + assert(FloatType === typeOfObject(3.4028235E38.toFloat)) + + // DoubleType + assert(DoubleType === typeOfObject(1.7976931348623157E308)) + + // DecimalType + assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) + + // TimestampType + assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00"))) + + // NullType + assert(NullType === typeOfObject(null)) + + def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + case value: java.math.BigDecimal => DecimalType + case _ => StringType + } + + assert(DecimalType === typeOfObject1( + new BigInteger("92233720368547758070"))) + assert(DecimalType === typeOfObject1( + new java.math.BigDecimal("1.7976931348623157E318"))) + assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) + + def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + } + + intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) + + def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { + case c: Seq[_] => ArrayType(typeOfObject3(c.head)) + } + + assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) + assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 
58f8c341e6676..999c9fff38d60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class ExpressionEvaluationSuite extends FunSuite { test("literals") { - assert((Literal(1) + Literal(1)).eval(null) === 2) + checkEvaluation(Literal(1), 1) + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(0L), 0L) + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal(1) + Literal(1), 2) } /** @@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - val expr = ! Literal(v, BooleanType) - val result = expr.eval(null) - if (result != answer) - fail(s"$expr should not evaluate to $result, expected: $answer") } + checkEvaluation(!Literal(v, BooleanType), answer) + } } booleanLogicTest("AND", _ && _, @@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal(null, StringType).like("a"), null) checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) @@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23) + checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23) - checkEvaluation("23" cast DecimalType, 23) - checkEvaluation("23" cast ByteType, 23) - checkEvaluation("23" cast ShortType, 23) + checkEvaluation("23" cast FloatType, 23f) + checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast ByteType, 23.toByte) + checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) checkEvaluation(Literal(123) cast IntegerType, 123) - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24) + checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24) + checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) + checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) + checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} @@ -391,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite { val typeMap = MapType(StringType, StringType) val typeArray = ArrayType(StringType) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + 
checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(null, IntegerType)), null, row) - checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row) + checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) val typeS_notNullable = StructType( @@ -413,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS)()), "a").nullable === true) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false) + assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala new file mode 100644 index 0000000000000..245a2e148030c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use code generation for evaluation. 
+ */ +class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val plan = try { + GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val evaluated = GenerateProjection.expressionEvaluator(expression) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + + test("multithreaded eval") { + import scala.concurrent._ + import ExecutionContext.Implicits.global + import scala.concurrent.duration._ + + val futures = (1 to 20).map { _ => + future { + GeneratePredicate(EqualTo(Literal(1), Literal(1))) + GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + } + } + + futures.foreach(Await.result(_, 10.seconds)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala new file mode 100644 index 0000000000000..887aabb1d5fb4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use generated code on mutable rows. 
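Both generated suites route every checkEvaluation call through the code-generation path rather than interpreted eval. A minimal illustrative sketch of that path, mirroring the calls made in the suite above (the object name is hypothetical):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._

    // Hypothetical example: evaluate one expression through the generated projection.
    object GeneratedProjectionExample {
      def main(args: Array[String]): Unit = {
        val expression = Add(Literal(1), Literal(2))

        // GenerateMutableProjection yields a () => MutableProjection factory;
        // applying the projection to a row produces the evaluated result.
        val projection = GenerateMutableProjection(Alias(expression, "result")() :: Nil)()
        println(projection(EmptyRow).apply(0))   // 3
      }
    }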
+ */ +class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + + val plan = try { + GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](expected)) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 4896f1b955f01..e2ae0d25db1a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Combine Limit", FixedPoint(2), + Batch("Combine Limit", FixedPoint(10), CombineLimits) :: - Batch("Constant Folding", FixedPoint(3), + Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, BooleanSimplification) :: Nil diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c309c43804d97..c8016e41256d5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -32,7 +32,7 @@ Spark Project SQL http://spark.apache.org/ - sql + sql @@ -68,6 +68,11 @@ jackson-databind 2.3.0 + + junit + junit + test + org.scalatest scalatest_${scala.binary.version} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java new file mode 100644 index 0000000000000..b73a371e93001 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing Lists. + * An ArrayType object comprises two fields, {@code DataType elementType} and + * {@code boolean containsNull}. 
The field of {@code elementType} is used to specify the type of + * array elements. The field of {@code containsNull} is used to specify if the array has + * {@code null} values. + * + * To create an {@link ArrayType}, + * {@link DataType#createArrayType(DataType)} or + * {@link DataType#createArrayType(DataType, boolean)} + * should be used. + */ +public class ArrayType extends DataType { + private DataType elementType; + private boolean containsNull; + + protected ArrayType(DataType elementType, boolean containsNull) { + this.elementType = elementType; + this.containsNull = containsNull; + } + + public DataType getElementType() { + return elementType; + } + + public boolean isContainsNull() { + return containsNull; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ArrayType arrayType = (ArrayType) o; + + if (containsNull != arrayType.containsNull) return false; + if (!elementType.equals(arrayType.elementType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = elementType.hashCode(); + result = 31 * result + (containsNull ? 1 : 0); + return result; + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java similarity index 63% rename from examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala rename to sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java index 42ae960bd64a1..7daad60f62a0b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java @@ -15,19 +15,13 @@ * limitations under the License. */ -package org.apache.spark.examples.pythonconverters - -import org.apache.spark.api.python.Converter -import org.apache.hadoop.hbase.client.Result -import org.apache.hadoop.hbase.util.Bytes +package org.apache.spark.sql.api.java; /** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts a HBase Result - * to a String + * The data type representing byte[] values. + * + * {@code BinaryType} is represented by the singleton object {@link DataType#BinaryType}. */ -class HBaseConverter extends Converter[Any, String] { - override def convert(obj: Any): String = { - val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) - } +public class BinaryType extends DataType { + protected BinaryType() {} } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java new file mode 100644 index 0000000000000..5a1f52725631b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing boolean and Boolean values. + * + * {@code BooleanType} is represented by the singleton object {@link DataType#BooleanType}. + */ +public class BooleanType extends DataType { + protected BooleanType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java new file mode 100644 index 0000000000000..e5cdf06b21bbe --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing byte and Byte values. + * + * {@code ByteType} is represented by the singleton object {@link DataType#ByteType}. + */ +public class ByteType extends DataType { + protected ByteType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java new file mode 100644 index 0000000000000..3eccddef88134 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The base type of all Spark SQL data types. + * + * To get/create specific data type, users should use singleton objects and factory methods + * provided by this class. 
+ */ +public abstract class DataType { + + /** + * Gets the StringType object. + */ + public static final StringType StringType = new StringType(); + + /** + * Gets the BinaryType object. + */ + public static final BinaryType BinaryType = new BinaryType(); + + /** + * Gets the BooleanType object. + */ + public static final BooleanType BooleanType = new BooleanType(); + + /** + * Gets the TimestampType object. + */ + public static final TimestampType TimestampType = new TimestampType(); + + /** + * Gets the DecimalType object. + */ + public static final DecimalType DecimalType = new DecimalType(); + + /** + * Gets the DoubleType object. + */ + public static final DoubleType DoubleType = new DoubleType(); + + /** + * Gets the FloatType object. + */ + public static final FloatType FloatType = new FloatType(); + + /** + * Gets the ByteType object. + */ + public static final ByteType ByteType = new ByteType(); + + /** + * Gets the IntegerType object. + */ + public static final IntegerType IntegerType = new IntegerType(); + + /** + * Gets the LongType object. + */ + public static final LongType LongType = new LongType(); + + /** + * Gets the ShortType object. + */ + public static final ShortType ShortType = new ShortType(); + + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}). + * The field of {@code containsNull} is set to {@code false}. + */ + public static ArrayType createArrayType(DataType elementType) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, false); + } + + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and + * whether the array contains null values ({@code containsNull}). + */ + public static ArrayType createArrayType(DataType elementType, boolean containsNull) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, containsNull); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}) and values + * ({@code valueType}). The field of {@code valueContainsNull} is set to {@code true}. + */ + public static MapType createMapType(DataType keyType, DataType valueType) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType, true); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of + * values ({@code valueType}), and whether values contain any null value + * ({@code valueContainsNull}). + */ + public static MapType createMapType( + DataType keyType, + DataType valueType, + boolean valueContainsNull) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType, valueContainsNull); + } + + /** + * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and + * whether values of this field can be null values ({@code nullable}).
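A minimal sketch of how these factory methods can be driven (shown from Scala; the object, field names and values are illustrative, not part of the patch):

    import org.apache.spark.sql.api.java.{DataType => JDataType, StructType => JStructType}

    // Hypothetical example: build a schema through the Java API's factory methods.
    object JavaDataTypeExample {
      def main(args: Array[String]): Unit = {
        val fields = Array(
          JDataType.createStructField("name", JDataType.StringType, false),
          JDataType.createStructField("tags", JDataType.createArrayType(JDataType.StringType), true))

        // createStructType rejects null entries and duplicate field names.
        val schema: JStructType = JDataType.createStructType(fields)
        println(schema.getFields.length)   // 2
      }
    }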
+ */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + if (name == null) { + throw new IllegalArgumentException("name should not be null."); + } + if (dataType == null) { + throw new IllegalArgumentException("dataType should not be null."); + } + + return new StructField(name, dataType, nullable); + } + + /** + * Creates a StructType with the given list of StructFields ({@code fields}). + */ + public static StructType createStructType(List fields) { + return createStructType(fields.toArray(new StructField[0])); + } + + /** + * Creates a StructType with the given StructField array ({@code fields}). + */ + public static StructType createStructType(StructField[] fields) { + if (fields == null) { + throw new IllegalArgumentException("fields should not be null."); + } + Set distinctNames = new HashSet(); + for (StructField field: fields) { + if (field == null) { + throw new IllegalArgumentException( + "fields should not contain any null."); + } + + distinctNames.add(field.getName()); + } + if (distinctNames.size() != fields.length) { + throw new IllegalArgumentException("fields should have distinct names."); + } + + return new StructType(fields); + } + +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java new file mode 100644 index 0000000000000..bc54c078d7a4e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing java.math.BigDecimal values. + * + * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. + */ +public class DecimalType extends DataType { + protected DecimalType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java new file mode 100644 index 0000000000000..f0060d0bcf9f5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing double and Double values. + * + * {@code DoubleType} is represented by the singleton object {@link DataType#DoubleType}. + */ +public class DoubleType extends DataType { + protected DoubleType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java new file mode 100644 index 0000000000000..4a6a37f69176a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing float and Float values. + * + * {@code FloatType} is represented by the singleton object {@link DataType#FloatType}. + */ +public class FloatType extends DataType { + protected FloatType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java new file mode 100644 index 0000000000000..bfd70490bbbbb --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing int and Integer values. + * + * {@code IntegerType} is represented by the singleton object {@link DataType#IntegerType}. 
+ */ +public class IntegerType extends DataType { + protected IntegerType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java new file mode 100644 index 0000000000000..af13a46eb165c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing long and Long values. + * + * {@code LongType} is represented by the singleton object {@link DataType#LongType}. + */ +public class LongType extends DataType { + protected LongType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java new file mode 100644 index 0000000000000..063e6b34abc48 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing Maps. A MapType object comprises three fields, + * {@code DataType keyType}, {@code DataType valueType}, and {@code boolean valueContainsNull}. + * The field of {@code keyType} is used to specify the type of keys in the map. + * The field of {@code valueType} is used to specify the type of values in the map. + * The field of {@code valueContainsNull} is used to specify if map values have + * {@code null} values. + * For values of a MapType column, keys are not allowed to have {@code null} values. + * + * To create a {@link MapType}, + * {@link DataType#createMapType(DataType, DataType)} or + * {@link DataType#createMapType(DataType, DataType, boolean)} + * should be used.
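A small illustrative sketch of the MapType factories described above: the two-argument form defaults valueContainsNull to true, and instances compare by value (the object name is hypothetical, not part of the patch):

    import org.apache.spark.sql.api.java.DataType

    // Hypothetical example: structurally identical MapTypes are equal.
    object JavaMapTypeExample {
      def main(args: Array[String]): Unit = {
        val byDefault = DataType.createMapType(DataType.StringType, DataType.IntegerType)
        val explicit  = DataType.createMapType(DataType.StringType, DataType.IntegerType, true)
        println(byDefault.equals(explicit))        // true: two-argument factory sets valueContainsNull to true
        println(byDefault.isValueContainsNull())   // true
      }
    }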
+ */ +public class MapType extends DataType { + private DataType keyType; + private DataType valueType; + private boolean valueContainsNull; + + protected MapType(DataType keyType, DataType valueType, boolean valueContainsNull) { + this.keyType = keyType; + this.valueType = valueType; + this.valueContainsNull = valueContainsNull; + } + + public DataType getKeyType() { + return keyType; + } + + public DataType getValueType() { + return valueType; + } + + public boolean isValueContainsNull() { + return valueContainsNull; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + MapType mapType = (MapType) o; + + if (valueContainsNull != mapType.valueContainsNull) return false; + if (!keyType.equals(mapType.keyType)) return false; + if (!valueType.equals(mapType.valueType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = keyType.hashCode(); + result = 31 * result + valueType.hashCode(); + result = 31 * result + (valueContainsNull ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java new file mode 100644 index 0000000000000..7d7604b4e3d2d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing short and Short values. + * + * {@code ShortType} is represented by the singleton object {@link DataType#ShortType}. + */ +public class ShortType extends DataType { + protected ShortType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java new file mode 100644 index 0000000000000..f4ba0c07c9c6e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing String values. + * + * {@code StringType} is represented by the singleton object {@link DataType#StringType}. + */ +public class StringType extends DataType { + protected StringType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java new file mode 100644 index 0000000000000..b48e2a2c5f953 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * A StructField object represents a field in a StructType object. + * A StructField object comprises three fields, {@code String name}, {@code DataType dataType}, + * and {@code boolean nullable}. The field of {@code name} is the name of a StructField. + * The field of {@code dataType} specifies the data type of a StructField. + * The field of {@code nullable} specifies if values of a StructField can contain {@code null} + * values. + * + * To create a {@link StructField}, + * {@link DataType#createStructField(String, DataType, boolean)} + * should be used. + */ +public class StructField { + private String name; + private DataType dataType; + private boolean nullable; + + protected StructField(String name, DataType dataType, boolean nullable) { + this.name = name; + this.dataType = dataType; + this.nullable = nullable; + } + + public String getName() { + return name; + } + + public DataType getDataType() { + return dataType; + } + + public boolean isNullable() { + return nullable; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructField that = (StructField) o; + + if (nullable != that.nullable) return false; + if (!dataType.equals(that.dataType)) return false; + if (!name.equals(that.name)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + dataType.hashCode(); + result = 31 * result + (nullable ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java new file mode 100644 index 0000000000000..a4b501efd9a10 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.Arrays; + +/** + * The data type representing Rows. + * A StructType object comprises an array of StructFields. + * + * To create an {@link StructType}, + * {@link DataType#createStructType(java.util.List)} or + * {@link DataType#createStructType(StructField[])} + * should be used. + */ +public class StructType extends DataType { + private StructField[] fields; + + protected StructType(StructField[] fields) { + this.fields = fields; + } + + public StructField[] getFields() { + return fields; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructType that = (StructType) o; + + if (!Arrays.equals(fields, that.fields)) return false; + + return true; + } + + @Override + public int hashCode() { + return Arrays.hashCode(fields); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java new file mode 100644 index 0000000000000..06d44c731cdfe --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing java.sql.Timestamp values. + * + * {@code TimestampType} is represented by the singleton object {@link DataType#TimestampType}. 
+ */ +public class TimestampType extends DataType { + protected TimestampType() {} +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java similarity index 95% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java index 53603614518f5..67007a9f0d1a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java @@ -18,4 +18,4 @@ /** * Allows the execution of relational queries, including those expressed in SQL using Spark. */ -package org.apache.spark.sql; \ No newline at end of file +package org.apache.spark.sql.api.java; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2b787e14f3f15..2d407077be303 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -21,47 +21,85 @@ import java.util.Properties import scala.collection.JavaConverters._ +object SQLConf { + val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" + val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" + val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" + val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + val CODEGEN_ENABLED = "spark.sql.codegen" + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + } +} + /** - * SQLConf holds mutable config parameters and hints. These can be set and - * queried either by passing SET commands into Spark SQL's DSL - * functions (sql(), hql(), etc.), or by programmatically using setters and - * getters of this class. + * A trait that enables the setting and getting of mutable config parameters/hints. * - * SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads). + * In the presence of a SQLContext, these can be set and queried by passing SET commands + * into Spark SQL's query functions (sql(), hql(), etc.). Otherwise, users of this trait can + * modify the hints by programmatically calling the setters and getters of this trait. + * + * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ trait SQLConf { + import SQLConf._ + + @transient protected[spark] val settings = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? + /** When true tables cached using the in-memory columnar caching will be compressed. */ + private[spark] def useCompression: Boolean = get(COMPRESS_CACHED, "false").toBoolean + /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt + + /** + * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode + * that evaluates expressions found in queries. 
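A short illustrative sketch of how the SQLConf keys defined above might be set, assuming SQLContext mixes in the SQLConf trait as the trait's scaladoc implies (the helper name and values are hypothetical, not part of the patch):

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.{SQLConf, SQLContext}

    // Hypothetical helper: configure the hints programmatically on a SQLContext.
    object SQLConfExample {
      def configure(sc: SparkContext): SQLContext = {
        val sqlContext = new SQLContext(sc)
        sqlContext.set(SQLConf.SHUFFLE_PARTITIONS, "50")
        sqlContext.set(SQLConf.CODEGEN_ENABLED, "true")            // experimental bytecode generation
        sqlContext.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, "-1") // -1 disables auto broadcast joins
        // The same keys can also be set from SQL, e.g. sqlContext.sql("SET spark.sql.codegen=true")
        sqlContext
      }
    }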
In general this custom code runs much faster + * than interpreted evaluation, but there are significant start-up costs due to compilation. + * As a result codegen is only beneficial when queries run for a long time, or when the same + * expressions are used multiple times. + * + * Defaults to false as this feature is currently experimental. + */ + private[spark] def codegenEnabled: Boolean = + if (get(CODEGEN_ENABLED, "false") == "true") true else false /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to - * a broadcast value during the physical executions of join operations. Setting this to 0 + * a broadcast value during the physical executions of join operations. Setting this to -1 effectively disables auto conversion. - * Hive setting: hive.auto.convert.join.noconditionaltask.size. + * + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. */ - private[spark] def autoConvertJoinSize: Int = - get("spark.sql.auto.convert.join.size", "10000").toInt + private[spark] def autoBroadcastJoinThreshold: Int = + get(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt - /** A comma-separated list of table names marked to be broadcasted during joins. */ - private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "") + /** + * The default size in bytes to assign to a logical operator's estimation statistics. By default, + * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator without a + * properly implemented estimation of this statistic will not be incorrectly broadcasted in joins. + */ + private[spark] def defaultSizeInBytes: Long = + getOption(DEFAULT_SIZE_IN_BYTES).map(_.toLong).getOrElse(autoBroadcastJoinThreshold + 1) /** ********************** SQLConf functionality methods ************ */ - @transient - private val settings = java.util.Collections.synchronizedMap( - new java.util.HashMap[String, String]()) - def set(props: Properties): Unit = { - props.asScala.foreach { case (k, v) => this.settings.put(k, v) } + settings.synchronized { + props.asScala.foreach { case (k, v) => settings.put(k, v) } + } } def set(key: String, value: String): Unit = { require(key != null, "key cannot be null") - require(value != null, s"value cannot be null for ${key}") + require(value != null, s"value cannot be null for key: $key") settings.put(key, value) } @@ -88,5 +126,5 @@ trait SQLConf { private[spark] def clear() { settings.clear() } - } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4abd89955bd27..86338752a21c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,11 +24,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -86,7 +85,45
@@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) + + /** + * :: DeveloperApi :: + * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, a runtime exception will be thrown. + * Example: + * {{{ + * import org.apache.spark.sql._ + * val sqlContext = new org.apache.spark.sql.SQLContext(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val peopleSchemaRDD = sqlContext.applySchema(people, schema) + * peopleSchemaRDD.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * peopleSchemaRDD.registerAsTable("people") + * sqlContext.sql("select name from people").collect.foreach(println) + * }}} + * + * @group userf + */ + @DeveloperApi + def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { + // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied + // schema differs from the existing schema on any field data type. + val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self) + new SchemaRDD(this, logicalPlan) + } /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. @@ -94,7 +131,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. @@ -104,6 +141,19 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + /** + * :: Experimental :: + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[SchemaRDD]]. + * + * @group userf + */ + @Experimental + def jsonFile(path: String, schema: StructType): SchemaRDD = { + val json = sparkContext.textFile(path) + jsonRDD(json, schema) + } + /** + * :: Experimental :: + */ @@ -124,10 +174,28 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[SchemaRDD]].
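A brief illustrative sketch of the schema-aware JSON loaders added here; the file path and field names are hypothetical, and the wildcard import mirrors the applySchema example above:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql._

    // Hypothetical example: load JSON with a user-supplied schema, skipping inference.
    object JsonWithSchemaExample {
      def run(sc: SparkContext): Unit = {
        val sqlContext = new SQLContext(sc)
        val schema =
          StructType(
            StructField("name", StringType, true) ::
            StructField("age", IntegerType, true) :: Nil)

        // Applies the given schema instead of sampling the data to infer one.
        val people = sqlContext.jsonFile("examples/src/main/resources/people.json", schema)
        people.printSchema()
        people.registerAsTable("people")
        sqlContext.sql("SELECT name FROM people WHERE age >= 13").collect().foreach(println)
      }
    }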
+ * + * @group userf */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = - new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) + def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val appliedSchema = + Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) + } + + /** + * :: Experimental :: + */ + @Experimental + def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) + } /** * :: Experimental :: @@ -160,7 +228,8 @@ class SQLContext(@transient val sparkContext: SparkContext) conf: Configuration = new Configuration()): SchemaRDD = { new SchemaRDD( this, - ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf)) + ParquetRelation.createEmpty( + path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) } /** @@ -170,11 +239,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - val name = tableName - val newPlan = rdd.logicalPlan transform { - case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = name) - } - catalog.registerTable(None, tableName, newPlan) + catalog.registerTable(None, tableName, rdd.logicalPlan) } /** @@ -196,8 +261,6 @@ class SQLContext(@transient val sparkContext: SparkContext) currentTable.logicalPlan case _ => - val useCompression = - sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) } @@ -212,7 +275,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) - catalog.registerTable(None, tableName, SparkLogicalPlan(e)) + catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) case inMem: InMemoryRelation => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) @@ -234,12 +297,14 @@ class SQLContext(@transient val sparkContext: SparkContext) val sqlContext: SQLContext = self + def codegenEnabled = self.codegenEnabled + def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = CommandStrategy(self) :: TakeOrdered :: - PartialAggregation :: + HashAggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: @@ -297,27 +362,30 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** - * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and - * inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. */ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: - Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + Batch("Add exchange", Once, AddExchange(self)) :: Nil } /** + * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. 
Designed to allow easy * access to the intermediate phases of query execution for developers. */ + @DeveloperApi protected abstract class QueryExecution { def logical: LogicalPlan lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... - lazy val sparkPlan = planner(optimizedPlan).next() + lazy val sparkPlan = { + SparkPlan.currentContext.set(self) + planner(optimizedPlan).next() + } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) @@ -337,41 +405,146 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} + |Code Generation: ${executedPlan.codegenEnabled} + |== RDD == + |${stringOrError(toRdd.toDebugString)} """.stripMargin.trim } /** * Peek at the first row of the RDD and infer its schema. - * TODO: consolidate this with the type system developed in SPARK-2060. + * It is only used by PySpark. */ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { import scala.collection.JavaConversions._ - def typeFor(obj: Any): DataType = obj match { - case c: java.lang.String => StringType - case c: java.lang.Integer => IntegerType - case c: java.lang.Long => LongType - case c: java.lang.Double => DoubleType - case c: java.lang.Boolean => BooleanType - case c: java.util.List[_] => ArrayType(typeFor(c.head)) - case c: java.util.Set[_] => ArrayType(typeFor(c.head)) + + def typeOfComplexValue: PartialFunction[Any, DataType] = { + case c: java.util.Calendar => TimestampType + case c: java.util.List[_] => + ArrayType(typeOfObject(c.head)) case c: java.util.Map[_, _] => val (key, value) = c.head - MapType(typeFor(key), typeFor(value)) + MapType(typeOfObject(key), typeOfObject(value)) case c if c.getClass.isArray => val elem = c.asInstanceOf[Array[_]].head - ArrayType(typeFor(elem)) + ArrayType(typeOfObject(elem)) case c => throw new Exception(s"Object of type $c cannot be used") } - val schema = rdd.first().map { case (fieldName, obj) => - AttributeReference(fieldName, typeFor(obj), true)() + def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue + + val firstRow = rdd.first() + val fields = firstRow.map { + case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true) }.toSeq - val rowRdd = rdd.mapPartitions { iter => - iter.map { map => - new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row + applySchemaToPythonRDD(rdd, StructType(fields)) + } + + /** + * Parses the data type in our internal string representation. The data type string should + * have the same format as the one generated by `toString` in scala. + * It is only used by PySpark. + */ + private[sql] def parseDataType(dataTypeString: String): DataType = { + val parser = org.apache.spark.sql.catalyst.types.DataType + parser(dataTypeString) + } + + /** + * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. + */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Map[String, _]], + schemaString: String): SchemaRDD = { + val schema = parseDataType(schemaString).asInstanceOf[StructType] + applySchemaToPythonRDD(rdd, schema) + } + + /** + * Apply a schema defined by the schema to an RDD. It is only used by PySpark. 
+ */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Map[String, _]], + schema: StructType): SchemaRDD = { + // TODO: We should have a better implementation once we do not turn a Python side record + // to a Map. + import scala.collection.JavaConversions._ + import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} + + def needsConversion(dataType: DataType): Boolean = dataType match { + case ByteType => true + case ShortType => true + case FloatType => true + case TimestampType => true + case ArrayType(_, _) => true + case MapType(_, _, _) => true + case StructType(_) => true + case other => false + } + + // Converts value to the type specified by the data type. + // Because Python does not have data types for TimestampType, FloatType, ShortType, and + // ByteType, we need to explicitly convert values in columns of these data types to the desired + // JVM data types. + def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match { + // TODO: We should check nullable + case (null, _) => null + + case (c: java.util.List[_], ArrayType(elementType, _)) => + val converted = c.map { e => convert(e, elementType)} + JListWrapper(converted) + + case (c: java.util.Map[_, _], struct: StructType) => + val row = new GenericMutableRow(struct.fields.length) + struct.fields.zipWithIndex.foreach { + case (field, i) => + val value = convert(c.get(field.name), field.dataType) + row.update(i, value) + } + row + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val converted = c.map { + case (key, value) => + (convert(key, keyType), convert(value, valueType)) + } + JMapWrapper(converted) + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType)) + converted: Seq[Any] + + case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) + case (c: Int, ByteType) => c.toByte + case (c: Int, ShortType) => c.toShort + case (c: Double, FloatType) => c.toFloat + + case (c, _) => c + } + + val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { + rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) }) + } else { + rdd + } + + val rowRdd = convertedRdd.mapPartitions { iter => + val row = new GenericMutableRow(schema.fields.length) + val fieldsWithIndex = schema.fields.zipWithIndex + iter.map { m => + // We cannot use m.values because the order of values returned by m.values may not + // match fields order. 
+ fieldsWithIndex.foreach { + case (field, i) => + val value = + m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull + row.update(i, value) + } + + row: Row } } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) - } + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 31d27bb4f0571..420f21fb9c1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList, Set => JSet} +import java.util.{Map => JMap, List => JList} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD @@ -120,6 +119,11 @@ class SchemaRDD( override protected def getDependencies: Seq[Dependency[_]] = List(new OneToOneDependency(queryExecution.toRdd)) + /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + def schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL @@ -376,39 +380,29 @@ class SchemaRDD( * Converts a JavaRDD to a PythonRDD. It is used by pyspark. 
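The schema accessor added to SchemaRDD above exposes the analyzed output as a StructType, so the column layout can be inspected programmatically rather than only through schemaString; a small usage sketch (query and column names illustrative):

    val people = sqlContext.sql("SELECT name, age FROM people")
    people.schema.fields.foreach { f =>
      println(s"${f.name}: ${f.dataType} (nullable = ${f.nullable})")
    }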
*/ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + import scala.collection.Map + + def toJava(obj: Any, dataType: DataType): Any = dataType match { + case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct) + case array: ArrayType => obj match { + case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava + case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava + case arr if arr != null && arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + case other => other + } + case mt: MapType => obj.asInstanceOf[Map[_, _]].map { + case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type + }.asJava + // Pyrolite can handle Timestamp + case other => obj + } def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { val fields = structType.fields.map(field => (field.name, field.dataType)) val map: JMap[String, Any] = new java.util.HashMap row.zip(fields).foreach { - case (obj, (attrName, dataType)) => - dataType match { - case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) - case array @ ArrayType(struct: StructType) => - val arrayValues = obj match { - case seq: Seq[Any] => - seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava - case list: JList[_] => - list.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case set: JSet[_] => - set.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case arr if arr != null && arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map { - element => rowToMap(element.asInstanceOf[Row], struct) - } - case other => other - } - map.put(attrName, arrayValues) - case array: ArrayType => { - val arrayValues = obj match { - case seq: Seq[Any] => seq.asJava - case other => other - } - map.put(attrName, arrayValues) - } - case other => map.put(attrName, obj) - } + case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType)) } - map } @@ -430,7 +424,8 @@ class SchemaRDD( * @group schema */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))) + new SchemaRDD(sqlContext, + SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))(sqlContext)) } // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index fe81721943202..6a20def475822 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -56,7 +56,7 @@ private[sql] trait SchemaRDDLike { // happen right away to let these side effects take place eagerly. case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => queryExecution.toRdd - SparkLogicalPlan(queryExecution.executedPlan) + SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) case _ => baseLogicalPlan } @@ -123,9 +123,15 @@ private[sql] trait SchemaRDDLike { def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd - /** Returns the output schema in the tree format. */ - def schemaString: String = queryExecution.analyzed.schemaString + /** Returns the schema as a string in the tree format. 
+ * + * @group schema + */ + def schemaString: String = baseSchemaRDD.schema.treeString - /** Prints out the schema in the tree format. */ + /** Prints out the schema. + * + * @group schema + */ def printSchema(): Unit = println(schemaString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 790d9ef22cf16..809dd038f94aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -21,14 +21,15 @@ import java.beans.Introspector import org.apache.hadoop.conf.Configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.types.util.DataTypeConversions +import DataTypeConversions.asScalaDataType; import org.apache.spark.util.Utils /** @@ -72,7 +73,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { conf: Configuration = new Configuration()): JavaSchemaRDD = { new JavaSchemaRDD( sqlContext, - ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf)) + ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext)) } /** @@ -92,7 +93,22 @@ class JavaSQLContext(val sqlContext: SQLContext) { new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow } } - new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) + new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext)) + } + + /** + * :: DeveloperApi :: + * Creates a JavaSchemaRDD from an RDD containing Rows by applying a schema to this RDD. + * It is important to make sure that the structure of every Row of the provided RDD matches the + * provided schema. Otherwise, there will be runtime exception. + */ + @DeveloperApi + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): JavaSchemaRDD = { + val scalaRowRDD = rowRDD.rdd.map(r => r.row) + val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] + val logicalPlan = + SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) } /** @@ -101,26 +117,52 @@ class JavaSQLContext(val sqlContext: SQLContext) { def parquetFile(path: String): JavaSchemaRDD = new JavaSchemaRDD( sqlContext, - ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext)) /** - * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. 
- * - * @group userf */ def jsonFile(path: String): JavaSchemaRDD = jsonRDD(sqlContext.sparkContext.textFile(path)) + /** + * :: Experimental :: + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonFile(path: String, schema: StructType): JavaSchemaRDD = + jsonRDD(sqlContext.sparkContext.textFile(path), schema) + /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[JavaSchemaRDD]]. + * JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. - * - * @group userf */ - def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0)) + def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { + val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = + SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) + } + + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { + val appliedScalaSchema = + Option(asScalaDataType(schema)).getOrElse( + JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = + SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) + } /** * Registers the given RDD as a temporary table in the catalog. 
Temporary tables exist only @@ -138,22 +180,37 @@ class JavaSQLContext(val sqlContext: SQLContext) { val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + case c: Class[_] if c == classOf[java.lang.String] => + (org.apache.spark.sql.StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => + (org.apache.spark.sql.ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => + (org.apache.spark.sql.IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => + (org.apache.spark.sql.LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => + (org.apache.spark.sql.DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => + (org.apache.spark.sql.ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => + (org.apache.spark.sql.FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => + (org.apache.spark.sql.BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => + (org.apache.spark.sql.ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => + (org.apache.spark.sql.IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => + (org.apache.spark.sql.LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => + (org.apache.spark.sql.DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => + (org.apache.spark.sql.ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => + (org.apache.spark.sql.FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => + (org.apache.spark.sql.BooleanType, true) } AttributeReference(property.getName, dataType, nullable)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 8fbf13b8b0150..4d799b4038fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -22,8 +22,10 @@ import java.util.{List => JList} import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} 
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import DataTypeConversions._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -53,6 +55,10 @@ class JavaSchemaRDD( override def toString: String = baseSchemaRDD.toString + /** Returns the schema of this JavaSchemaRDD (represented by a StructType). */ + def schema: StructType = + asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType] + // ======================================================================= // Base RDD functions that do NOT change schema // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 9b0dd2176149b..6c67934bda5b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.api.java +import scala.annotation.varargs +import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} +import scala.collection.JavaConversions +import scala.math.BigDecimal + import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -29,7 +34,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { /** Returns the value of column `i`. */ def get(i: Int): Any = - row(i) + Row.toJavaValue(row(i)) /** Returns true if value at column `i` is NULL. */ def isNullAt(i: Int) = get(i) == null @@ -89,5 +94,57 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { */ def getString(i: Int): String = row.getString(i) + + def canEqual(other: Any): Boolean = other.isInstanceOf[Row] + + override def equals(other: Any): Boolean = other match { + case that: Row => + (that canEqual this) && + row == that.row + case _ => false + } + + override def hashCode(): Int = row.hashCode() } +object Row { + + private def toJavaValue(value: Any): Any = value match { + // For values of this ScalaRow, we will do the conversion when + // they are actually accessed. + case row: ScalaRow => new Row(row) + case map: scala.collection.Map[_, _] => + JavaConversions.mapAsJavaMap( + map.map { + case (key, value) => (toJavaValue(key), toJavaValue(value)) + } + ) + case seq: scala.collection.Seq[_] => + JavaConversions.seqAsJavaList(seq.map(toJavaValue)) + case decimal: BigDecimal => decimal.underlying() + case other => other + } + + // TODO: Consolidate the toScalaValue at here with the scalafy in JsonRDD? + private def toScalaValue(value: Any): Any = value match { + // Values of this row have been converted to Scala values. + case row: Row => row.row + case map: java.util.Map[_, _] => + JMapWrapper(map).map { + case (key, value) => (toScalaValue(key), toScalaValue(value)) + } + case list: java.util.List[_] => + JListWrapper(list).map(toScalaValue) + case decimal: java.math.BigDecimal => BigDecimal(decimal) + case other => other + } + + /** + * Creates a Row with the given values. + */ + @varargs def create(values: Any*): Row = { + // Right now, we cannot use @varargs to annotate the constructor of + // org.apache.spark.sql.api.java.Row. See https://issues.scala-lang.org/browse/SI-8383. 
+ new Row(ScalaRow(values.map(toScalaValue):_*)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 74f5630fbddf1..c416a745739b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -154,6 +154,7 @@ private[sql] object ColumnBuilder { case STRING.typeId => new StringColumnBuilder case BINARY.typeId => new BinaryColumnBuilder case GENERIC.typeId => new GenericColumnBuilder + case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c1ced8bfa404a..463a1d32d7fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -42,8 +42,8 @@ case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sqlContext: SQLContext) - extends UnaryNode with NoBind { + child: SparkPlan) + extends UnaryNode { override def requiredChildDistribution = if (partial) { @@ -56,8 +56,6 @@ case class Aggregate( } } - override def otherCopyArgs = sqlContext :: Nil - // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. private[this] val childOutput = child.output @@ -138,7 +136,7 @@ case class Aggregate( i += 1 } } - val resultProjection = new Projection(resultExpressions, computedSchema) + val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) val aggregateResults = new GenericMutableRow(computedAggregates.length) var i = 0 @@ -152,7 +150,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) var currentRow: Row = null while (iter.hasNext) { @@ -175,7 +173,8 @@ case class Aggregate( private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = - new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + new InterpretedMutableProjection( + resultExpressions, computedSchema ++ namedGroups.map(_._2)) private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 00010ef6e798a..30712f03cab4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{NoBind, 
MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning = newPartitioning @@ -42,12 +42,14 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions, child.output) + @transient val hashExpressions = + newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part) + val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -60,7 +62,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(row => mutablePair.update(row, null)) } val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part) + val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._1) @@ -71,7 +73,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(null, r)) } val partitioner = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner) + val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 47b3d00262dbb..c386fd121c5de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -47,23 +47,26 @@ case class Generate( } } - override def output = + // This must be a val since the generator output expr ids are not preserved by serialization. + override val output = if (join) child.output ++ generatorOutput else generatorOutput + val boundGenerator = BindReferences.bindReference(generator, child.output) + override def execute() = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) // Used to produce rows with no matches when outer = true. 
val outerProjection = - new Projection(child.output ++ nullValues, child.output) + newProjection(child.output ++ nullValues, child.output) val joinProjection = - new Projection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generator.output, child.output ++ generator.output) val joinedRow = new JoinedRow iter.flatMap {row => - val outputRows = generator.eval(row) + val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { outerProjection(row) :: Nil } else { @@ -72,7 +75,7 @@ case class Generate( } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row))) + child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala new file mode 100644 index 0000000000000..4a26934c49c93 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ + +case class AggregateEvaluation( + schema: Seq[Attribute], + initialValues: Seq[Expression], + update: Seq[Expression], + result: Expression) + +/** + * :: DeveloperApi :: + * Alternate version of aggregation that leverages projection and thus code generation. + * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto + * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source.
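Read against the Sum branch that follows, the conversion amounts to a tiny per-group state machine; a conceptual sketch (not literal generated code), assuming `value` is the aggregated expression:

    // buffer layout : [currentSum]                         (one attribute per accumulator)
    // initialize    : currentSum = Cast(Literal(0L), value.dataType)
    // per input row : currentSum = Coalesce(Add(value, currentSum) :: currentSum :: Nil)
    // final result  : currentSum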
+ */ +@DeveloperApi +case class GeneratedAggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def output = aggregateExpressions.map(_.toAttribute) + + override def execute() = { + val aggregatesToCompute = aggregateExpressions.flatMap { a => + a.collect { case agg: AggregateExpression => agg} + } + + val computeFunctions = aggregatesToCompute.map { + case c @ Count(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val initialValue = Literal(0L) + val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val result = currentCount + + AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case Sum(expr) => + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialValue = Cast(Literal(0L), expr.dataType) + + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + val result = currentSum + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case a @ Average(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialCount = Literal(0L) + val initialSum = Cast(Literal(0L), expr.dataType) + val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + + val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + + AggregateEvaluation( + currentCount :: currentSum :: Nil, + initialCount :: initialSum :: Nil, + updateCount :: updateSum :: Nil, + result + ) + } + + val computationSchema = computeFunctions.flatMap(_.schema) + + val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => agg.id -> func.result + }.toMap + + val namedGroups = groupingExpressions.zipWithIndex.map { + case (ne: NamedExpression, _) => (ne, ne) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + } + + val groupMap: Map[Expression, Attribute] = + namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + + // The set of expressions that produce the final output given the aggregation buffer and the + // grouping expressions. + val resultExpressions = aggregateExpressions.map(_.transform { + case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if groupMap.contains(e) => groupMap(e) + }) + + child.execute().mapPartitions { iter => + // Builds a new custom class for holding the results of aggregation for a group. + val initialValues = computeFunctions.flatMap(_.initialValues) + val newAggregationBuffer = newProjection(initialValues, child.output) + log.info(s"Initial values: ${initialValues.mkString(",")}") + + // A projection that computes the group given an input tuple. 
+ val groupProjection = newProjection(groupingExpressions, child.output) + log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") + + // A projection that is used to update the aggregate values for a group given a new tuple. + // This projection should be targeted at the current values for the group and then applied + // to a joined row of the current values with the new input row. + val updateExpressions = computeFunctions.flatMap(_.update) + val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output + val updateProjection = newMutableProjection(updateExpressions, updateSchema)() + log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") + + // A projection that produces the final result, given a computation. + val resultProjectionBuilder = + newMutableProjection( + resultExpressions, + (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + log.info(s"Result Projection: ${resultExpressions.mkString(",")}") + + val joinedRow = new JoinedRow + + if (groupingExpressions.isEmpty) { + // TODO: Codegening anything other than the updateProjection is probably over kill. + val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + var currentRow: Row = null + updateProjection.target(buffer) + + while (iter.hasNext) { + currentRow = iter.next() + updateProjection(joinedRow(buffer, currentRow)) + } + + val resultProjection = resultProjectionBuilder() + Iterator(resultProjection(buffer)) + } else { + val buffers = new java.util.HashMap[Row, MutableRow]() + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupProjection(currentRow) + var currentBuffer = buffers.get(currentGroup) + if (currentBuffer == null) { + currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + buffers.put(currentGroup, currentBuffer) + } + // Target the projection at the current aggregation buffer and then project the updated + // values. 
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val resultIterator = buffers.entrySet.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext = resultIterator.hasNext + + def next() = { + val currentGroup = resultIterator.next() + resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 27dc091b85812..21cbbc9772a00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -18,22 +18,55 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Logging, Row} + + +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.BaseRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ + +object SparkPlan { + protected[sql] val currentContext = new ThreadLocal[SQLContext]() +} + /** * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { self: Product => + /** + * A handle to the SQL Context that was used to create this plan. Since many operators need + * access to the sqlContext for RDD operations or configuration, this field is automatically + * populated by the query planning infrastructure. + */ + @transient + protected val sqlContext = SparkPlan.currentContext.get() + + protected def sparkContext = sqlContext.sparkContext + + // sqlContext will be null when we are being deserialized on the slaves. In this instance + // the value of codegenEnabled will be set by the deserializer after the constructor has run. + val codegenEnabled: Boolean = if (sqlContext != null) { + sqlContext.codegenEnabled + } else { + false + } + + /** Overridden makeCopy also propagates sqlContext to the copied plan. */ + override def makeCopy(newArgs: Array[AnyRef]): this.type = { + SparkPlan.currentContext.set(sqlContext) + super.makeCopy(newArgs) + } + // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
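Because codegenEnabled is captured from the originating SQLContext when the plan is built, exercising the generated-code path only requires flipping the flag on that context before planning; a minimal sketch, assuming the flag is surfaced through a conf key named "spark.sql.codegen" (the key name is an assumption, only SQLContext.codegenEnabled appears in this patch):

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    // Hypothetical conf key; SparkPlan reads sqlContext.codegenEnabled at planning time.
    sqlContext.set("spark.sql.codegen", "true")
    sqlContext.sql("SELECT key, SUM(value) FROM records GROUP BY key").collect()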
@@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { */ def executeCollect(): Array[Row] = execute().map(_.copy()).collect() - protected def buildRow(values: Seq[Any]): Row = - new GenericRow(values.toArray) + protected def newProjection( + expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + GenerateProjection(expressions, inputSchema) + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if(codegenEnabled) { + GenerateMutableProjection(expressions, inputSchema) + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + + + protected def newPredicate( + expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + if (codegenEnabled) { + GeneratePredicate(expression, inputSchema) + } else { + InterpretedPredicate(expression, inputSchema) + } + } + + protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + if (codegenEnabled) { + GenerateOrdering(order, inputSchema) + } else { + new RowOrdering(order, inputSchema) + } + } } /** @@ -66,8 +137,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { * linking. */ @DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "SparkLogicalPlan") - extends BaseRelation with MultiInstanceRelation { +case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { def output = alreadyPlanned.output override def references = Set.empty @@ -78,9 +149,15 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "Spar alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) case _ => sys.error("Multiple instance of the same relation detected.") - }, tableName) - .asInstanceOf[this.type] + })(sqlContext).asInstanceOf[this.type] } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. 
+ sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) + } private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c078e71fe0290..8bec015c7b465 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ @@ -39,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition)(sqlContext) :: Nil + planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -47,9 +47,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be * evaluated by matching hash keys. + * + * This strategy applies a simple optimization based on the estimates of the physical sizes of + * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * estimated physical size smaller than the user-settable threshold + * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the + * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be + * ''broadcasted'' to all of the executors involved in the join, as a + * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they + * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. 
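In practice the planner's build-side choice is driven entirely by this size estimate against the threshold, so it can be tuned or disabled per context; a short sketch, assuming AUTO_BROADCASTJOIN_THRESHOLD is backed by a conf key named "spark.sql.autoBroadcastJoinThreshold" (the key name is an assumption, and the table names are illustrative):

    // Broadcast the smaller side of an equi-join whenever its estimated size is at most 10 MB;
    // a non-positive threshold makes the planner fall back to shuffled hash joins.
    sqlContext.set("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)
    sqlContext.sql("SELECT o.id, c.name FROM orders o JOIN customers c ON o.cid = c.id")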
*/ object HashJoin extends Strategy with PredicateHelper { - private[this] def broadcastHashJoin( + + private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: LogicalPlan, @@ -57,102 +67,102 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition: Option[Expression], side: BuildSide) = { val broadcastHashJoin = execution.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } - def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys( - Inner, - leftKeys, - rightKeys, - condition, - left, - right @ PhysicalOperation(_, _, b: BaseRelation)) - if broadcastTables.contains(b.tableName) => - broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) - case ExtractEquiJoinKeys( - Inner, - leftKeys, - rightKeys, - condition, - left @ PhysicalOperation(_, _, b: BaseRelation), - right) - if broadcastTables.contains(b.tableName) => - broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + BuildRight + } else { + BuildLeft + } val hashJoin = execution.ShuffledHashJoin( - leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + execution.HashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case _ => Nil } } - object PartialAggregation extends Strategy { + object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. 
However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + // Aggregations that can be performed in two phases, before and after the shuffle. - // Construct two phased aggregation. - execution.Aggregate( + // Cases where all aggregates can be codegened. + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) + if canBeCodeGened( + allAggregates(partialComputation) ++ + allAggregates(rewrittenAggregateExpressions)) && + codegenEnabled => + execution.GeneratedAggregate( partial = false, - namedGroupingExpressions.values.map(_.toAttribute).toSeq, + namedGroupingAttributes, rewrittenAggregateExpressions, - execution.Aggregate( + execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, - planLater(child))(sqlContext))(sqlContext) :: Nil - } else { - Nil - } + planLater(child))) :: Nil + + // Cases where some aggregate can not be codegened + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) => + execution.Aggregate( + partial = false, + namedGroupingAttributes, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + groupingExpressions, + partialComputation, + planLater(child))) :: Nil + case _ => Nil } + + def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + case _: Sum | _: Count => false + case _ => true + } + + def allAggregates(exprs: Seq[Expression]) = + exprs.flatMap(_.collect { case a: AggregateExpression => a }) } object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } } @@ -171,16 +181,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - def convertToCatalyst(a: Any): Any = a match { - case s: Seq[Any] => s.map(convertToCatalyst) - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TakeOrdered(limit, order, 
planLater(child))(sqlContext) :: Nil + execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } } @@ -190,11 +194,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // TODO: need to support writing to other types of files. Unify the below code paths. case logical.WriteToFile(path, child) => val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { @@ -223,7 +227,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil + ParquetTableScan(_, relation, filters)) :: Nil case _ => Nil } @@ -261,20 +265,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => - val dataAsRdd = - sparkContext.parallelize(data.map(r => - new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) - execution.ExistingRdd(output, dataAsRdd) :: Nil + ExistingRdd( + output, + ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child))(sqlContext) :: Nil + execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => - execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil - case logical.Except(left,right) => - execution.Except(planLater(left),planLater(right)) :: Nil + execution.Union(unionChildren.map(planLater)) :: Nil + case logical.Except(left, right) => + execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => @@ -283,7 +286,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil + case SparkLogicalPlan(existingPlan) => existingPlan :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 966d8f95fc83c..0027f3cf1fc79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output = projectList.map(_.toAttribute) - override def execute() = child.execute().mapPartitions { iter => - @transient val reusableProjection = new MutableProjection(projectList) - iter.map(reusableProjection) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) + + def execute() = child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map(reusableProjection) } } @@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output = child.output - override def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.eval(_).asInstanceOf[Boolean]) + @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + + def execute() = child.execute().mapPartitions { iter => + iter.filter(conditionEvaluator) } } @@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: * :: DeveloperApi :: */ @DeveloperApi -case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan { +case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output = children.head.output - override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) - - override def otherCopyArgs = sqlContext :: Nil + override def execute() = sparkContext.union(children.map(_.execute())) } /** @@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex * repartition all the data to a single partition to compute the global limit. */ @DeveloperApi -case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) +case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again - override def otherCopyArgs = sqlContext :: Nil - override def output = child.output /** @@ -148,7 +148,7 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext iter.take(limit).map(row => mutablePair.update(false, row)) } val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) + val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.mapPartitions(_.take(limit).map(_._2)) } @@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
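Queries shaped as ORDER BY ... LIMIT n are exactly what the TakeOrdered strategy earlier in this patch matches, so they bypass a global sort; a minimal sketch (table and column names illustrative):

    // Planned as TakeOrdered(5, age DESC, ...) rather than a full Sort followed by a Limit.
    sqlContext.sql("SELECT name, age FROM people ORDER BY age DESC LIMIT 5").collect()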
*/ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sqlContext: SQLContext) extends UnaryNode { - override def otherCopyArgs = sqlContext :: Nil +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output = child.output - @transient - lazy val ordering = new RowOrdering(sortOrder) + val ordering = new RowOrdering(sortOrder, child.output) + // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(executeCollect(), 1) } /** @@ -189,15 +187,13 @@ case class Sort( override def requiredChildDistribution = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - @transient - lazy val ordering = new RowOrdering(sortOrder) override def execute() = attachTree(this, "sort") { - // TODO: Optimize sorting operation? child.execute() - .mapPartitions( - iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, - preservesPartitioning = true) + .mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) } override def output = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 98d2f89c8ae71..9293239131d52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { /** @@ -44,28 +45,53 @@ trait Command { case class SetCommand( key: Option[String], value: Option[String], output: Seq[Attribute])( @transient context: SQLContext) - extends LeafNode with Command { + extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => - context.set(k, v) - Array(k -> v) + if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") + context.set(SQLConf.SHUFFLE_PARTITIONS, v) + Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + } else { + context.set(k, v) + Array(s"$k=$v") + } // Query the value bound to key k. case (Some(k), _) => - Array(k -> context.getOption(k).getOrElse("")) + // TODO (lian) This is just a workaround to make the Simba ODBC driver work. + // Should remove this once we get the ODBC driver updated. 
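Editor's note: SetCommand's result changes here from (key, value) pairs to plain "key=value" strings, and the deprecated mapred.reduce.tasks property is silently rewritten to spark.sql.shuffle.partitions. A self-contained sketch of that rewrite (property names taken from the hunk above, everything else simplified and hypothetical):
{{{
import scala.collection.mutable

object SetCommandSketch {
  val ShufflePartitions     = "spark.sql.shuffle.partitions"
  val DeprecatedReduceTasks = "mapred.reduce.tasks"

  // Mirrors the (Some(k), Some(v)) branch of sideEffectResult: mutate the conf and
  // report what was actually set as a single "k=v" line.
  def set(conf: mutable.Map[String, String], key: String, value: String): Seq[String] =
    if (key == DeprecatedReduceTasks) {
      conf(ShufflePartitions) = value           // deprecated key is converted transparently
      Seq(s"$ShufflePartitions=$value")
    } else {
      conf(key) = value
      Seq(s"$key=$value")
    }

  def main(args: Array[String]): Unit = {
    val conf = mutable.Map.empty[String, String]
    println(set(conf, DeprecatedReduceTasks, "10"))   // List(spark.sql.shuffle.partitions=10)
    println(set(conf, "spark.sql.codegen", "true"))   // List(spark.sql.codegen=true)
  }
}
}}}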
+ if (k == "-v") { + val hiveJars = Seq( + "hive-exec-0.12.0.jar", + "hive-service-0.12.0.jar", + "hive-common-0.12.0.jar", + "hive-hwi-0.12.0.jar", + "hive-0.12.0.jar").mkString(":") + + Array( + "system:java.class.path=" + hiveJars, + "system:sun.java.command=shark.SharkServer2") + } + else { + Array(s"$k=${context.getOption(k).getOrElse("")}") + } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - context.getAll + context.getAll.map { case (k, v) => + s"$k=$v" + } case _ => throw new IllegalArgumentException() } def execute(): RDD[Row] = { - val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } context.sparkContext.parallelize(rows, 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c6fbd6d2f6930..f31df051824d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -41,13 +41,13 @@ package object debug { */ @DeveloperApi implicit class DebugQuery(query: SchemaRDD) { - def debug(implicit sc: SparkContext): Unit = { + def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[Long]() val debugPlan = plan transform { case s: SparkPlan if !visited.contains(s.id) => visited += s.id - DebugNode(sc, s) + DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { @@ -57,9 +57,7 @@ package object debug { } } - private[sql] case class DebugNode( - @transient sparkContext: SparkContext, - child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { def references = Set.empty def output = child.output @@ -107,7 +105,9 @@ package object debug { var i = 0 while (i < numColumns) { val value = currentRow(i) - columnStats(i).elementTypes += HashSet(value.getClass.getName) + if (value != null) { + columnStats(i).elementTypes += HashSet(value.getClass.getName) + } i += 1 } currentRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 7d1f11caae838..82f0a74b630bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide case object BuildRight extends BuildSide trait HashJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] val buildSide: BuildSide @@ -56,9 +58,9 @@ trait HashJoin { def output = left.output ++ right.output - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + newMutableProjection(streamedKeys, streamedPlan.output) def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. 
@@ -70,7 +72,7 @@ trait HashJoin { while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { val newMatchList = new ArrayBuffer[Row]() @@ -134,6 +136,185 @@ trait HashJoin { } } +/** + * Constant Value for Binary Join Node + */ +object HashOuterJoin { + val DUMMY_LIST = Seq[Row](null) + val EMPTY_LIST = Seq[Row]() +} + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def output = left.output ++ right.output + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. 
+ joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { + // TODO: Use Spark's HashMap implementation. + val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + existingMatchList += currentRow.copy() + } + + hashTable.toMap[Row, ArrayBuffer[Row]] + } + + def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) 
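Editor's note: the HashOuterJoin.DUMMY_LIST.filter(_ => !matched) idiom used by the three iterators above is just a compact way to emit exactly one null-padded row when no row from the other side satisfied the join condition. A plain-collections sketch of the same left-outer logic (hypothetical Seq-based rows, not Catalyst rows):
{{{
object LeftOuterPaddingSketch {
  type Row = Seq[Any]
  val DummyList = Seq[Row](null)      // single element, only used to run the padding map once

  def leftOuterIterator(leftRow: Row, rightRows: Seq[Row], rightWidth: Int,
                        condition: (Row, Row) => Boolean): Iterator[Row] = {
    var matched = false
    val joined = rightRows.collect {
      case r if condition(leftRow, r) =>
        matched = true
        leftRow ++ r
    }
    // Emits one row padded with nulls on the right only if nothing matched.
    val padded = DummyList.filter(_ => !matched).map(_ => leftRow ++ Seq.fill(rightWidth)(null))
    (joined ++ padded).iterator
  }

  def main(args: Array[String]): Unit = {
    val right = Seq(Seq[Any](1, "x"), Seq[Any](2, "y"))
    leftOuterIterator(Seq[Any]("a", 1), right, 2, (l, r) => l(1) == r(0)).foreach(println)
    // List(a, 1, 1, x)
    leftOuterIterator(Seq[Any]("b", 9), right, 2, (l, r) => l(1) == r(0)).foreach(println)
    // List(b, 9, null, null)
  }
}
}}}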
+ // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case x => throw new Exception(s"Need to add implementation for $x") + } + } + } +} + /** * :: DeveloperApi :: * Performs an inner hash join of two child relations by first shuffling the data using the join @@ -187,7 +368,7 @@ case class LeftSemiJoinHash( while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { hashSet.add(rowKey) @@ -217,9 +398,8 @@ case class BroadcastHashJoin( rightKeys: Seq[Expression], buildSide: BuildSide, left: SparkPlan, - right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { + right: SparkPlan) extends BinaryNode with HashJoin { - override def otherCopyArgs = sqlContext :: Nil override def outputPartitioning: Partitioning = left.outputPartitioning @@ -228,7 +408,7 @@ case class BroadcastHashJoin( @transient lazy val broadcastFuture = future { - sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + sparkContext.broadcast(buildPlan.executeCollect()) } def execute() = { @@ -248,14 +428,11 @@ case class BroadcastHashJoin( @DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. 
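Editor's note: LeftSemiJoinHash (touched above only for spacing) keeps each streamed row at most once when its key appears in a hash set built from the other side, and never emits columns from that side. For context, a tiny stand-alone model of that semantics (hypothetical helper, not the real operator):
{{{
object LeftSemiJoinSketch {
  def leftSemiJoin[K, L, R](left: Seq[L], right: Seq[R])(leftKey: L => K, rightKey: R => K): Seq[L] = {
    val keys = right.map(rightKey).toSet          // hash set built from the "build" side
    left.filter(l => keys.contains(leftKey(l)))   // each left row appears at most once
  }

  def main(args: Array[String]): Unit = {
    val employees     = Seq(("alice", 1), ("bob", 2), ("carol", 3))
    val activeDeptIds = Seq(1, 1, 3)              // duplicates on the right do not duplicate output
    println(leftSemiJoin[Int, (String, Int), Int](employees, activeDeptIds)(_._2, identity))
    // List((alice,1), (carol,3))
  }
}
}}}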
override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - def output = left.output /** The Streamed Relation */ @@ -271,7 +448,7 @@ case class LeftSemiJoinBNL( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow @@ -300,8 +477,14 @@ case class LeftSemiJoinBNL( case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { def output = left.output ++ right.output - def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { - case (l: Row, r: Row) => buildRow(l ++ r) + def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } } } @@ -310,14 +493,20 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod */ @DeveloperApi case class BroadcastNestedLoopJoin( - streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - (@transient sqlContext: SQLContext) - extends BinaryNode { + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. - override def outputPartitioning: Partitioning = streamed.outputPartitioning + /** BuildRight means the right relation <=> the broadcast relation. */ + val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } - override def otherCopyArgs = sqlContext :: Nil + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output = { joinType match { @@ -332,11 +521,6 @@ case class BroadcastNestedLoopJoin( } } - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - @transient lazy val boundCondition = InterpretedPredicate( condition @@ -345,58 +529,80 @@ case class BroadcastNestedLoopJoin( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] // TODO: Use Spark's BitSet. - val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 - var matched = false + var streamRowMatched = false while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. 
val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += buildRow(streamedRow ++ broadcastedRow) - matched = true - includedBroadcastTuples += i + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => } i += 1 } - if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => } } Iterator((matchedRows, includedBroadcastTuples)) } - val includedBroadcastTuples = streamedPlusMatches.map(_._2) + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = if (includedBroadcastTuples.count == 0) { new scala.collection.mutable.BitSet(broadcastedRelation.value.size) } else { - streamedPlusMatches.map(_._2).reduce(_ ++ _) + includedBroadcastTuples.reduce(_ ++ _) } - val rightOuterMatches: Seq[Row] = - if (joinType == RightOuter || joinType == FullOuter) { - broadcastedRelation.value.zipWithIndex.filter { - case (row, i) => !allIncludedBroadcastTuples.contains(i) - }.map { - // TODO: Use projection. - case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case _ => + } } - } else { - Vector() + i += 1 } + arrBuf.toSeq + } // TODO: Breaks lineage. 
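Editor's note: the reworked BroadcastNestedLoopJoin above records, per partition, which broadcast tuples matched anything in a mutable BitSet and later appends null-padded rows for the broadcast tuples that never matched (for the outer join types that require it). A small stand-alone model of that bookkeeping for a right outer join with the right side broadcast (hypothetical Seq-based rows):
{{{
import scala.collection.mutable

object BroadcastNestedLoopSketch {
  type Row = Seq[Any]

  def rightOuterBroadcast(streamed: Seq[Row], broadcast: IndexedSeq[Row],
                          condition: (Row, Row) => Boolean): Seq[Row] = {
    val included = new mutable.BitSet(broadcast.size)
    val matches = streamed.flatMap { s =>
      broadcast.zipWithIndex.collect {
        case (b, i) if condition(s, b) =>
          included += i                  // this broadcast tuple has at least one match
          s ++ b
      }
    }
    val leftNulls: Row = Seq.fill(streamed.headOption.map(_.size).getOrElse(0))(null)
    val unmatchedBroadcast = broadcast.zipWithIndex.collect {
      case (b, i) if !included.contains(i) => leftNulls ++ b
    }
    matches ++ unmatchedBroadcast
  }

  def main(args: Array[String]): Unit = {
    val left  = Seq(Seq[Any](1, "a"), Seq[Any](2, "b"))
    val right = IndexedSeq(Seq[Any](1, "x"), Seq[Any](9, "z"))
    rightOuterBroadcast(left, right, (l, r) => l(0) == r(0)).foreach(println)
    // List(1, a, 1, x)
    // List(null, null, 9, z)
  }
}
}}}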
- sqlContext.sparkContext.union( - streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches)) + sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index b48c70ee73a27..70db1ebd3a3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.json +import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal @@ -25,30 +26,25 @@ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.Logging private[sql] object JsonRDD extends Logging { + private[sql] def jsonStringToRow( + json: RDD[String], + schema: StructType): RDD[Row] = { + parseJson(json).map(parsed => asRow(parsed, schema)) + } + private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0): LogicalPlan = { + samplingRatio: Double = 1.0): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) - val baseSchema = createSchema(allKeys) - - createLogicalPlan(json, baseSchema) - } - - private def createLogicalPlan( - json: RDD[String], - baseSchema: StructType): LogicalPlan = { - val schema = nullTypeToStringType(baseSchema) - - SparkLogicalPlan(ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema)))) + createSchema(allKeys) } private def createSchema(allKeys: Set[(String, DataType)]): StructType = { @@ -72,8 +68,8 @@ private[sql] object JsonRDD extends Logging { val (topLevel, structLike) = values.partition(_.size == 1) val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { - case ArrayType(StructType(Nil)) => false - case ArrayType(_) => true + case ArrayType(StructType(Nil), _) => false + case ArrayType(_, _) => true case struct: StructType => false case _ => true } @@ -87,7 +83,8 @@ private[sql] object JsonRDD extends Logging { val structType = makeStruct(nestedFields, prefix :+ name) val dataType = resolved.get(prefix :+ name).get dataType match { - case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true)) + case array: ArrayType => + Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true)) case struct: StructType => Some(StructField(name, structType, nullable = true)) // dataType is StringType means that we have resolved type conflicts involving // primitive types and complex types. 
So, the type of name has been relaxed to @@ -106,6 +103,22 @@ private[sql] object JsonRDD extends Logging { makeStruct(resolved.keySet.toSeq, Nil) } + private[sql] def nullTypeToStringType(struct: StructType): StructType = { + val fields = struct.fields.map { + case StructField(fieldName, dataType, nullable) => { + val newType = dataType match { + case NullType => StringType + case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) + case struct: StructType => nullTypeToStringType(struct) + case other: DataType => other + } + StructField(fieldName, newType, nullable) + } + } + + StructType(fields) + } + /** * Returns the most general data type for two given data types. */ @@ -136,8 +149,8 @@ private[sql] object JsonRDD extends Logging { case StructField(name, _, _) => name }) } - case (ArrayType(elementType1), ArrayType(elementType2)) => - ArrayType(compatibleType(elementType1, elementType2)) + case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // TODO: We should use JsonObjectStringType to mark that values of field will be // strings and every string is a Json object. case (_, _) => StringType @@ -145,18 +158,13 @@ private[sql] object JsonRDD extends Logging { } } - private def typeOfPrimitiveValue(value: Any): DataType = { - value match { - case value: java.lang.String => StringType - case value: java.lang.Integer => IntegerType - case value: java.lang.Long => LongType + private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { + ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. case value: java.math.BigInteger => DecimalType - case value: java.lang.Double => DoubleType + // DecimalType's JVMType is scala BigDecimal. case value: java.math.BigDecimal => DecimalType - case value: java.lang.Boolean => BooleanType - case null => NullType // Unexpected data type. case _ => StringType } @@ -169,12 +177,13 @@ private[sql] object JsonRDD extends Logging { * treat the element as String. */ private def typeOfArray(l: Seq[Any]): ArrayType = { + val containsNull = l.exists(v => v == null) val elements = l.flatMap(v => Option(v)) if (elements.isEmpty) { // If this JSON array is empty, we use NullType as a placeholder. // If this array is not empty in other JSON objects, we can resolve // the type after we have passed through all JSON objects. - ArrayType(NullType) + ArrayType(NullType, containsNull) } else { val elementType = elements.map { e => e match { @@ -186,7 +195,7 @@ private[sql] object JsonRDD extends Logging { } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - ArrayType(elementType) + ArrayType(elementType, containsNull) } } @@ -213,15 +222,16 @@ private[sql] object JsonRDD extends Logging { case (key: String, array: Seq[_]) => { // The value associated with the key is an array. typeOfArray(array) match { - case ArrayType(StructType(Nil)) => { + case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. 
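Editor's note: with containsNull added to ArrayType, the schema-inference merge in compatibleType now ORs the flags of two conflicting array types as well as merging their element types. A toy model of that merge using a tiny hypothetical type ADT (only the behaviour visible in this hunk, plus one assumed numeric-widening case):
{{{
object CompatibleTypeSketch {
  sealed trait DataType
  case object StringType  extends DataType
  case object IntegerType extends DataType
  case object LongType    extends DataType
  case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType

  def compatibleType(t1: DataType, t2: DataType): DataType = (t1, t2) match {
    case (a, b) if a == b => a
    // Assumed primitive widening, in the spirit of HiveTypeCoercion.
    case (IntegerType, LongType) | (LongType, IntegerType) => LongType
    // Arrays merge their element types and OR their nullability flags.
    case (ArrayType(e1, n1), ArrayType(e2, n2)) => ArrayType(compatibleType(e1, e2), n1 || n2)
    // Unresolvable conflicts are relaxed to strings, matching the JsonRDD fallback.
    case _ => StringType
  }

  def main(args: Array[String]): Unit = {
    println(compatibleType(ArrayType(IntegerType, containsNull = false),
                           ArrayType(LongType, containsNull = true)))
    // ArrayType(LongType,true)
  }
}
}}}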
array.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { case (k, dataType) => (s"$key.$k", dataType) - } :+ (key, ArrayType(StructType(Nil))) + } :+ (key, ArrayType(StructType(Nil), containsNull)) } - case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil + case ArrayType(elementType, containsNull) => + (key, ArrayType(elementType, containsNull)) :: Nil } } case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil @@ -259,8 +269,11 @@ private[sql] object JsonRDD extends Logging { // the ObjectMapper will take the last value associated with this duplicate key. // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() - iter.map(record => mapper.readValue(record, classOf[java.util.Map[String, Any]])) - }).map(scalafy).map(_.asInstanceOf[Map[String, Any]]) + iter.map { record => + val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) + parsed.asInstanceOf[Map[String, Any]] + } + }) } private def toLong(value: Any): Long = { @@ -331,7 +344,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case ArrayType(elementType) => + case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case StringType => toString(value) case IntegerType => value.asInstanceOf[IntegerType.JvmType] @@ -345,6 +358,7 @@ private[sql] object JsonRDD extends Logging { } private def asRow(json: Map[String,Any], schema: StructType): Row = { + // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { // StructType @@ -353,7 +367,7 @@ private[sql] object JsonRDD extends Logging { v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) // ArrayType(StructType) - case (StructField(name, ArrayType(structType: StructType), _), i) => + case (StructField(name, ArrayType(structType: StructType, _), _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( v => v.asInstanceOf[Seq[Any]].map( @@ -367,32 +381,4 @@ private[sql] object JsonRDD extends Logging { row } - - private def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType) => ArrayType(StringType) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - private def asAttributes(struct: StructType): Seq[AttributeReference] = { - struct.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) - } - - private def asStruct(attributes: Seq[AttributeReference]): StructType = { - val fields = attributes.map { - case AttributeReference(name, dataType, nullable) => StructField(name, dataType, nullable) - } - - StructType(fields) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala new file mode 100644 index 0000000000000..0995a4eb6299f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.DeveloperApi + +/** + * Allows the execution of relational queries, including those expressed in SQL using Spark. + * + * @groupname dataType Data types + * @groupdesc Spark SQL data types. + * @groupprio dataType -3 + * @groupname field Field + * @groupprio field -2 + * @groupname row Row + * @groupprio row -1 + */ +package object sql { + + protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging + + /** + * :: DeveloperApi :: + * + * Represents one row of output from a relational operator. + * @group row + */ + @DeveloperApi + type Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * A [[Row]] object can be constructed by providing field values. Example: + * {{{ + * import org.apache.spark.sql._ + * + * // Create a Row from values. + * Row(value1, value2, value3, ...) + * // Create a Row from a Seq of values. + * Row.fromSeq(Seq(value1, value2, ...)) + * }}} + * + * A value of a row can be accessed through both generic access by ordinal, + * which will incur boxing overhead for primitives, as well as native primitive access. + * An example of generic access by ordinal: + * {{{ + * import org.apache.spark.sql._ + * + * val row = Row(1, true, "a string", null) + * // row: Row = [1,true,a string,null] + * val firstValue = row(0) + * // firstValue: Any = 1 + * val fourthValue = row(3) + * // fourthValue: Any = null + * }}} + * + * For native primitive access, it is invalid to use the native primitive interface to retrieve + * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a + * value that might be null. + * An example of native primitive access: + * {{{ + * // using the row from the previous example. + * val firstValue = row.getInt(0) + * // firstValue: Int = 1 + * val isNull = row.isNullAt(3) + * // isNull: Boolean = true + * }}} + * + * Interfaces related to native primitive access are: + * + * `isNullAt(i: Int): Boolean` + * + * `getInt(i: Int): Int` + * + * `getLong(i: Int): Long` + * + * `getDouble(i: Int): Double` + * + * `getFloat(i: Int): Float` + * + * `getBoolean(i: Int): Boolean` + * + * `getShort(i: Int): Short` + * + * `getByte(i: Int): Byte` + * + * `getString(i: Int): String` + * + * Fields in a [[Row]] object can be extracted in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + * + * @group row + */ + @DeveloperApi + val Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * The base type of all Spark SQL data types. 
+ * + * @group dataType + */ + @DeveloperApi + type DataType = catalyst.types.DataType + + /** + * :: DeveloperApi :: + * + * The data type representing `String` values + * + * @group dataType + */ + @DeveloperApi + val StringType = catalyst.types.StringType + + /** + * :: DeveloperApi :: + * + * The data type representing `Array[Byte]` values. + * + * @group dataType + */ + @DeveloperApi + val BinaryType = catalyst.types.BinaryType + + /** + * :: DeveloperApi :: + * + * The data type representing `Boolean` values. + * + *@group dataType + */ + @DeveloperApi + val BooleanType = catalyst.types.BooleanType + + /** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Timestamp` values. + * + * @group dataType + */ + @DeveloperApi + val TimestampType = catalyst.types.TimestampType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * + * @group dataType + */ + @DeveloperApi + val DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `Double` values. + * + * @group dataType + */ + @DeveloperApi + val DoubleType = catalyst.types.DoubleType + + /** + * :: DeveloperApi :: + * + * The data type representing `Float` values. + * + * @group dataType + */ + @DeveloperApi + val FloatType = catalyst.types.FloatType + + /** + * :: DeveloperApi :: + * + * The data type representing `Byte` values. + * + * @group dataType + */ + @DeveloperApi + val ByteType = catalyst.types.ByteType + + /** + * :: DeveloperApi :: + * + * The data type representing `Int` values. + * + * @group dataType + */ + @DeveloperApi + val IntegerType = catalyst.types.IntegerType + + /** + * :: DeveloperApi :: + * + * The data type representing `Long` values. + * + * @group dataType + */ + @DeveloperApi + val LongType = catalyst.types.LongType + + /** + * :: DeveloperApi :: + * + * The data type representing `Short` values. + * + * @group dataType + */ + @DeveloperApi + val ShortType = catalyst.types.ShortType + + /** + * :: DeveloperApi :: + * + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * + * @group dataType + */ + @DeveloperApi + type ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * An [[ArrayType]] object can be constructed with two ways, + * {{{ + * ArrayType(elementType: DataType, containsNull: Boolean) + * }}} and + * {{{ + * ArrayType(elementType: DataType) + * }}} + * For `ArrayType(elementType)`, the field of `containsNull` is set to `false`. + * + * @group dataType + */ + @DeveloperApi + val ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * The data type representing `Map`s. A [[MapType]] object comprises three fields, + * `keyType: [[DataType]]`, `valueType: [[DataType]]` and `valueContainsNull: Boolean`. + * The field of `keyType` is used to specify the type of keys in the map. + * The field of `valueType` is used to specify the type of values in the map. + * The field of `valueContainsNull` is used to specify if values of this map has `null` values. + * For values of a MapType column, keys are not allowed to have `null` values. 
+ * + * @group dataType + */ + @DeveloperApi + type MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * A [[MapType]] object can be constructed with two ways, + * {{{ + * MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) + * }}} and + * {{{ + * MapType(keyType: DataType, valueType: DataType) + * }}} + * For `MapType(keyType: DataType, valueType: DataType)`, + * the field of `valueContainsNull` is set to `true`. + * + * @group dataType + */ + @DeveloperApi + val MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * The data type representing [[Row]]s. + * A [[StructType]] object comprises a [[Seq]] of [[StructField]]s. + * + * @group dataType + */ + @DeveloperApi + type StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Those names do not have matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). + * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object represents a field in a [[StructType]] object. + * A [[StructField]] object comprises three fields, `name: [[String]]`, `dataType: [[DataType]]`, + * and `nullable: Boolean`. The field of `name` is the name of a `StructField`. The field of + * `dataType` specifies the data type of a `StructField`. + * The field of `nullable` specifies if values of a `StructField` can contain `null` values. 
+ * + * @group field + */ + @DeveloperApi + type StructField = catalyst.types.StructField + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object can be constructed by + * {{{ + * StructField(name: String, dataType: DataType, nullable: Boolean) + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructField = catalyst.types.StructField +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index de8fe2dae38f6..0a3b59cbc233a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -75,21 +75,21 @@ private[sql] object CatalystConverter { val fieldType: DataType = field.dataType fieldType match { // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType) => { + case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) } // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType) => { + case ArrayType(elementType: DataType, false) => { new CatalystArrayConverter(elementType, fieldIndex, parent) } case StructType(fields: Seq[StructField]) => { new CatalystStructConverter(fields.toArray, fieldIndex, parent) } - case MapType(keyType: DataType, valueType: DataType) => { + case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { new CatalystMapConverter( Array( new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)), + new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), fieldIndex, parent) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 9c4771d1a9846..b3bae5db0edbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -27,6 +27,7 @@ import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} @@ -45,7 +46,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} */ private[sql] case class ParquetRelation( path: String, - @transient conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { + @transient conf: Option[Configuration], + @transient sqlContext: SQLContext) + extends LeafNode with MultiInstanceRelation { self: Product => @@ -59,7 +62,7 @@ private[sql] case class ParquetRelation( /** Attributes */ override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) - override def newInstance = ParquetRelation(path).asInstanceOf[this.type] + override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, @@ -68,6 +71,9 @@ private[sql] case class ParquetRelation( p.path == path && p.output == output case _ => 
false } + + // TODO: Use data from the footers. + override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes) } private[sql] object ParquetRelation { @@ -104,13 +110,14 @@ private[sql] object ParquetRelation { */ def create(pathString: String, child: LogicalPlan, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { if (!child.resolved) { throw new UnresolvedException[LogicalPlan]( child, "Attempt to create Parquet table from unresolved child (when schema is not available)") } - createEmpty(pathString, child.output, false, conf) + createEmpty(pathString, child.output, false, conf, sqlContext) } /** @@ -125,14 +132,15 @@ private[sql] object ParquetRelation { def createEmpty(pathString: String, attributes: Seq[Attribute], allowExisting: Boolean, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { val path = checkPath(pathString, allowExisting, conf) if (conf.get(ParquetOutputFormat.COMPRESSION) == null) { conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name()) } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString, Some(conf)) { + new ParquetRelation(path.toString, Some(conf), sqlContext) { override val output = attributes } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ea74320d06c86..759a2a586b926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -51,14 +51,20 @@ import org.apache.spark.{Logging, SerializableWritable, TaskContext} * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. */ case class ParquetTableScan( - // note: output cannot be transient, see - // https://issues.apache.org/jira/browse/SPARK-1367 - output: Seq[Attribute], + attributes: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Seq[Expression])( - @transient val sqlContext: SQLContext) + columnPruningPred: Seq[Expression]) extends LeafNode { + // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes + // by exprId. note: output cannot be transient, see + // https://issues.apache.org/jira/browse/SPARK-1367 + val output = attributes.map { a => + relation.output + .find(o => o.exprId == a.exprId) + .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relation.output.mkString(",")}")) + } + override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -99,8 +105,6 @@ case class ParquetTableScan( .filter(_ != null) // Parquet's record filters may produce null values } - override def otherCopyArgs = sqlContext :: Nil - /** * Applies a (candidate) projection. 
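Editor's note: ParquetTableScan now re-resolves the requested attributes against relation.output by exprId rather than by name, since Parquet attribute resolution is case sensitive. A stripped-down model of that lookup, with a hypothetical Attr class standing in for Catalyst attribute references:
{{{
object ExprIdResolutionSketch {
  case class Attr(name: String, exprId: Long)   // the exprId, not the name, is the identity

  def resolveByExprId(requested: Seq[Attr], relationOutput: Seq[Attr]): Seq[Attr] =
    requested.map { a =>
      relationOutput
        .find(o => o.exprId == a.exprId)
        .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relationOutput.mkString(",")}"))
    }

  def main(args: Array[String]): Unit = {
    val relationOutput = Seq(Attr("Key", 1L), Attr("Value", 2L))
    // The query may use different casing; the exprId still resolves to the relation's attribute.
    println(resolveByExprId(Seq(Attr("key", 1L)), relationOutput))   // List(Attr(Key,1))
  }
}
}}}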
* @@ -110,10 +114,9 @@ case class ParquetTableScan( def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { val success = validateProjection(prunedAttributes) if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext) + ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") - this } } @@ -150,8 +153,7 @@ case class ParquetTableScan( case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, - overwrite: Boolean = false)( - @transient val sqlContext: SQLContext) + overwrite: Boolean = false) extends UnaryNode with SparkHadoopMapReduceUtil { /** @@ -171,7 +173,7 @@ case class InsertIntoParquetTable( val writeSupport = if (child.output.map(_.dataType).forall(_.isPrimitive)) { - logger.debug("Initializing MutableRowWriteSupport") + log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { classOf[org.apache.spark.sql.parquet.RowWriteSupport] @@ -203,8 +205,6 @@ case class InsertIntoParquetTable( override def output = child.output - override def otherCopyArgs = sqlContext :: Nil - /** * Stores the given Row RDD as a Hadoop file. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 39294a3f4bf5a..6d4ce32ac5bfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -172,10 +172,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - case t @ ArrayType(_) => writeArray( + case t @ ArrayType(_, false) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) - case t @ MapType(_, _) => writeMap( + case t @ MapType(_, _, _) => writeMap( t, value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) case t @ StructType(_) => writeStruct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index d4599da711254..837ea7695dbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.Job +import org.apache.spark.sql.test.TestSQLContext import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup @@ -103,7 +104,7 @@ private[sql] object ParquetTestData { val testDir = Utils.createTempDir() val testFilterDir = Utils.createTempDir() - lazy val testData = new ParquetRelation(testDir.toURI.toString) + lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext) val testNestedSchema1 = // based on blogpost example, source: @@ -202,8 +203,10 @@ private[sql] object ParquetTestData { val testNestedDir3 = Utils.createTempDir() val testNestedDir4 = Utils.createTempDir() - lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) - lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + lazy 
val testNestedData1 = + new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext) + lazy val testNestedData2 = + new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext) def writeFile() = { testDir.delete() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 58370b955a5ec..aaef1a1d474fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -116,7 +116,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - new ArrayType(toDataType(field)) + ArrayType(toDataType(field), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -130,7 +130,9 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. + MapType(keyType, valueType) } case _ => { // Note: the order of these checks is important! @@ -140,10 +142,12 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. + MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType val elementType = toDataType(groupType.getFields.apply(0)) - new ArrayType(elementType) + ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields @@ -151,7 +155,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ptype.getName, toDataType(ptype), ptype.getRepetition != Repetition.REQUIRED)) - new StructType(fields) + StructType(fields) } } } @@ -234,7 +238,7 @@ private[parquet] object ParquetTypesConverter extends Logging { new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) }.getOrElse { ctype match { - case ArrayType(elementType) => { + case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, @@ -248,7 +252,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } new ParquetGroupType(repetition, name, fields) } - case MapType(keyType, valueType) => { + case MapType(keyType, valueType, _) => { val parquetKeyType = fromDataType( keyType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala new file mode 100644 index 0000000000000..77353f4eb0227 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types.util + +import org.apache.spark.sql._ +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField} + +import scala.collection.JavaConverters._ + +protected[sql] object DataTypeConversions { + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asJavaStructField(scalaStructField: StructField): JStructField = { + JDataType.createStructField( + scalaStructField.name, + asJavaDataType(scalaStructField.dataType), + scalaStructField.nullable) + } + + /** + * Returns the equivalent DataType in Java for the given DataType in Scala. + */ + def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case StringType => JDataType.StringType + case BinaryType => JDataType.BinaryType + case BooleanType => JDataType.BooleanType + case TimestampType => JDataType.TimestampType + case DecimalType => JDataType.DecimalType + case DoubleType => JDataType.DoubleType + case FloatType => JDataType.FloatType + case ByteType => JDataType.ByteType + case IntegerType => JDataType.IntegerType + case LongType => JDataType.LongType + case ShortType => JDataType.ShortType + + case arrayType: ArrayType => JDataType.createArrayType( + asJavaDataType(arrayType.elementType), arrayType.containsNull) + case mapType: MapType => JDataType.createMapType( + asJavaDataType(mapType.keyType), + asJavaDataType(mapType.valueType), + mapType.valueContainsNull) + case structType: StructType => JDataType.createStructType( + structType.fields.map(asJavaStructField).asJava) + } + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asScalaStructField(javaStructField: JStructField): StructField = { + StructField( + javaStructField.getName, + asScalaDataType(javaStructField.getDataType), + javaStructField.isNullable) + } + + /** + * Returns the equivalent DataType in Scala for the given DataType in Java. 
+ */ + def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case stringType: org.apache.spark.sql.api.java.StringType => + StringType + case binaryType: org.apache.spark.sql.api.java.BinaryType => + BinaryType + case booleanType: org.apache.spark.sql.api.java.BooleanType => + BooleanType + case timestampType: org.apache.spark.sql.api.java.TimestampType => + TimestampType + case decimalType: org.apache.spark.sql.api.java.DecimalType => + DecimalType + case doubleType: org.apache.spark.sql.api.java.DoubleType => + DoubleType + case floatType: org.apache.spark.sql.api.java.FloatType => + FloatType + case byteType: org.apache.spark.sql.api.java.ByteType => + ByteType + case integerType: org.apache.spark.sql.api.java.IntegerType => + IntegerType + case longType: org.apache.spark.sql.api.java.LongType => + LongType + case shortType: org.apache.spark.sql.api.java.ShortType => + ShortType + + case arrayType: org.apache.spark.sql.api.java.ArrayType => + ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) + case mapType: org.apache.spark.sql.api.java.MapType => + MapType( + asScalaDataType(mapType.getKeyType), + asScalaDataType(mapType.getValueType), + mapType.isValueContainsNull) + case structType: org.apache.spark.sql.api.java.StructType => + StructType(structType.getFields.map(asScalaStructField)) + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java new file mode 100644 index 0000000000000..3c92906d82864 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. 
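+// The tests below exercise the typical applySchema pattern; roughly (a sketch of the calls used
+// further down, not additional test code):
+//   StructType schema = DataType.createStructType(fields);
+//   JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema);
+//   schemaRDD.registerAsTable("people");
+// as well as pairing an explicit schema with JSON input via javaSqlCtx.jsonRDD(jsonRDD, schema).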
+public class JavaApplySchemaSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite"); + javaSqlCtx = new JavaSQLContext(javaCtx); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + public static class Person implements Serializable { + private String name; + private int age; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + } + + @Test + public void applySchema() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return Row.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataType.createStructField("name", DataType.StringType, false)); + fields.add(DataType.createStructField("age", DataType.IntegerType, false)); + StructType schema = DataType.createStructType(fields); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); + schemaRDD.registerAsTable("people"); + List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + + List expected = new ArrayList(2); + expected.add(Row.create("Michael", 29)); + expected.add(Row.create("Yin", 28)); + + Assert.assertEquals(expected, actual); + } + + @Test + public void applySchemaToJSON() { + JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( + "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + + "\"boolean\":true, \"null\":null}", + "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + + "\"boolean\":false, \"null\":null}")); + List fields = new ArrayList(7); + fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); + fields.add(DataType.createStructField("double", DataType.DoubleType, true)); + fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); + fields.add(DataType.createStructField("long", DataType.LongType, true)); + fields.add(DataType.createStructField("null", DataType.StringType, true)); + fields.add(DataType.createStructField("string", DataType.StringType, true)); + StructType expectedSchema = DataType.createStructType(fields); + List expectedResult = new ArrayList(2); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758069"), + false, + 1.7976931348623157E305, + 11, + 21474836469L, + null, + "this is another simple string.")); + + JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD); + StructType actualSchema1 = schemaRDD1.schema(); + Assert.assertEquals(expectedSchema, actualSchema1); + 
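+    // The same rows are expected back whether the schema was inferred (schemaRDD1) or supplied
+    // explicitly (schemaRDD2 below).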
schemaRDD1.registerAsTable("jsonTable1"); + List actual1 = javaSqlCtx.sql("select * from jsonTable1").collect(); + Assert.assertEquals(expectedResult, actual1); + + JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); + StructType actualSchema2 = schemaRDD2.schema(); + Assert.assertEquals(expectedSchema, actualSchema2); + schemaRDD1.registerAsTable("jsonTable2"); + List actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); + Assert.assertEquals(expectedResult, actual2); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java new file mode 100644 index 0000000000000..52d07b5425cc3 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class JavaRowSuite { + private byte byteValue; + private short shortValue; + private int intValue; + private long longValue; + private float floatValue; + private double doubleValue; + private BigDecimal decimalValue; + private boolean booleanValue; + private String stringValue; + private byte[] binaryValue; + private Timestamp timestampValue; + + @Before + public void setUp() { + byteValue = (byte)127; + shortValue = (short)32767; + intValue = 2147483647; + longValue = 9223372036854775807L; + floatValue = (float)3.4028235E38; + doubleValue = 1.7976931348623157E308; + decimalValue = new BigDecimal("1.7976931348623157E328"); + booleanValue = true; + stringValue = "this is a string"; + binaryValue = stringValue.getBytes(); + timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); + } + + @Test + public void constructSimpleRow() { + Row simpleRow = Row.create( + byteValue, // ByteType + new Byte(byteValue), + shortValue, // ShortType + new Short(shortValue), + intValue, // IntegerType + new Integer(intValue), + longValue, // LongType + new Long(longValue), + floatValue, // FloatType + new Float(floatValue), + doubleValue, // DoubleType + new Double(doubleValue), + decimalValue, // DecimalType + booleanValue, // BooleanType + new Boolean(booleanValue), + stringValue, // StringType + binaryValue, // BinaryType + timestampValue, // TimestampType + null // null + ); + + Assert.assertEquals(byteValue, simpleRow.getByte(0)); + Assert.assertEquals(byteValue, simpleRow.get(0)); + Assert.assertEquals(byteValue, simpleRow.getByte(1)); + Assert.assertEquals(byteValue, simpleRow.get(1)); + 
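+    // The remaining numeric primitives (short, int, long, float, double) follow the same pattern:
+    // stored once as a primitive and once boxed, readable through both the typed getter and get().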
Assert.assertEquals(shortValue, simpleRow.getShort(2)); + Assert.assertEquals(shortValue, simpleRow.get(2)); + Assert.assertEquals(shortValue, simpleRow.getShort(3)); + Assert.assertEquals(shortValue, simpleRow.get(3)); + Assert.assertEquals(intValue, simpleRow.getInt(4)); + Assert.assertEquals(intValue, simpleRow.get(4)); + Assert.assertEquals(intValue, simpleRow.getInt(5)); + Assert.assertEquals(intValue, simpleRow.get(5)); + Assert.assertEquals(longValue, simpleRow.getLong(6)); + Assert.assertEquals(longValue, simpleRow.get(6)); + Assert.assertEquals(longValue, simpleRow.getLong(7)); + Assert.assertEquals(longValue, simpleRow.get(7)); + // When we create the row, we do not do any conversion + // for a float/double value, so we just set the delta to 0. + Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0); + Assert.assertEquals(floatValue, simpleRow.get(8)); + Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0); + Assert.assertEquals(floatValue, simpleRow.get(9)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0); + Assert.assertEquals(doubleValue, simpleRow.get(10)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0); + Assert.assertEquals(doubleValue, simpleRow.get(11)); + Assert.assertEquals(decimalValue, simpleRow.get(12)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(13)); + Assert.assertEquals(booleanValue, simpleRow.get(13)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(14)); + Assert.assertEquals(booleanValue, simpleRow.get(14)); + Assert.assertEquals(stringValue, simpleRow.getString(15)); + Assert.assertEquals(stringValue, simpleRow.get(15)); + Assert.assertEquals(binaryValue, simpleRow.get(16)); + Assert.assertEquals(timestampValue, simpleRow.get(17)); + Assert.assertEquals(true, simpleRow.isNullAt(18)); + Assert.assertEquals(null, simpleRow.get(18)); + } + + @Test + public void constructComplexRow() { + // Simple array + List simpleStringArray = Arrays.asList( + stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); + + // Simple map + Map simpleMap = new HashMap(); + simpleMap.put(stringValue + " (1)", longValue); + simpleMap.put(stringValue + " (2)", longValue - 1); + simpleMap.put(stringValue + " (3)", longValue - 2); + + // Simple struct + Row simpleStruct = Row.create( + doubleValue, stringValue, timestampValue, null); + + // Complex array + List> arrayOfMaps = Arrays.asList(simpleMap); + List arrayOfRows = Arrays.asList(simpleStruct); + + // Complex map + Map, Row> complexMap = new HashMap, Row>(); + complexMap.put(arrayOfRows, simpleStruct); + + // Complex struct + Row complexStruct = Row.create( + simpleStringArray, + simpleMap, + simpleStruct, + arrayOfMaps, + arrayOfRows, + complexMap, + null); + Assert.assertEquals(simpleStringArray, complexStruct.get(0)); + Assert.assertEquals(simpleMap, complexStruct.get(1)); + Assert.assertEquals(simpleStruct, complexStruct.get(2)); + Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); + Assert.assertEquals(arrayOfRows, complexStruct.get(4)); + Assert.assertEquals(complexMap, complexStruct.get(5)); + Assert.assertEquals(null, complexStruct.get(6)); + + // A very complex row + Row complexRow = Row.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); + Assert.assertEquals(arrayOfMaps, complexRow.get(0)); + Assert.assertEquals(arrayOfRows, complexRow.get(1)); + Assert.assertEquals(complexMap, complexRow.get(2)); + Assert.assertEquals(complexStruct, complexRow.get(3)); + } +} diff --git 
a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java new file mode 100644 index 0000000000000..d099a48a1f4b6 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.List; +import java.util.ArrayList; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.sql.types.util.DataTypeConversions; + +public class JavaSideDataTypeConversionSuite { + public void checkDataType(DataType javaDataType) { + org.apache.spark.sql.catalyst.types.DataType scalaDataType = + DataTypeConversions.asScalaDataType(javaDataType); + DataType actual = DataTypeConversions.asJavaDataType(scalaDataType); + Assert.assertEquals(javaDataType, actual); + } + + @Test + public void createDataTypes() { + // Simple DataTypes. + checkDataType(DataType.StringType); + checkDataType(DataType.BinaryType); + checkDataType(DataType.BooleanType); + checkDataType(DataType.TimestampType); + checkDataType(DataType.DecimalType); + checkDataType(DataType.DoubleType); + checkDataType(DataType.FloatType); + checkDataType(DataType.ByteType); + checkDataType(DataType.IntegerType); + checkDataType(DataType.LongType); + checkDataType(DataType.ShortType); + + // Simple ArrayType. + DataType simpleJavaArrayType = DataType.createArrayType(DataType.StringType, true); + checkDataType(simpleJavaArrayType); + + // Simple MapType. + DataType simpleJavaMapType = DataType.createMapType(DataType.StringType, DataType.LongType); + checkDataType(simpleJavaMapType); + + // Simple StructType. + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); + DataType simpleJavaStructType = DataType.createStructType(simpleFields); + checkDataType(simpleJavaStructType); + + // Complex StructType. 
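+    // Composite types used as field types (the array, map, and struct created above) should
+    // survive the Java -> Scala -> Java round trip performed by checkDataType.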
+ List complexFields = new ArrayList(); + complexFields.add(DataType.createStructField("simpleArray", simpleJavaArrayType, true)); + complexFields.add(DataType.createStructField("simpleMap", simpleJavaMapType, true)); + complexFields.add(DataType.createStructField("simpleStruct", simpleJavaStructType, true)); + complexFields.add(DataType.createStructField("boolean", DataType.BooleanType, false)); + DataType complexJavaStructType = DataType.createStructType(complexFields); + checkDataType(complexJavaStructType); + + // Complex ArrayType. + DataType complexJavaArrayType = DataType.createArrayType(complexJavaStructType, true); + checkDataType(complexJavaArrayType); + + // Complex MapType. + DataType complexJavaMapType = + DataType.createMapType(complexJavaStructType, complexJavaArrayType, false); + checkDataType(complexJavaMapType); + } + + @Test + public void illegalArgument() { + // ArrayType + try { + DataType.createArrayType(null, true); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // MapType + try { + DataType.createMapType(null, DataType.StringType); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(DataType.StringType, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(null, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // StructField + try { + DataType.createStructField(null, DataType.StringType, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField("name", null, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField(null, null, true); + } catch (IllegalArgumentException expectedException) { + } + + // StructType + try { + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(null); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala new file mode 100644 index 0000000000000..cf7d79f42db1d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +class DataTypeSuite extends FunSuite { + + test("construct an ArrayType") { + val array = ArrayType(StringType) + + assert(ArrayType(StringType, false) === array) + } + + test("construct an MapType") { + val map = MapType(StringType, IntegerType) + + assert(MapType(StringType, IntegerType, true) === map) + } + + test("extract fields from a StructType") { + val struct = StructType( + StructField("a", IntegerType, true) :: + StructField("b", LongType, false) :: + StructField("c", StringType, true) :: + StructField("d", FloatType, true) :: Nil) + + assert(StructField("b", LongType, false) === struct("b")) + + intercept[IllegalArgumentException] { + struct("e") + } + + val expectedStruct = StructType( + StructField("b", LongType, false) :: + StructField("d", FloatType, true) :: Nil) + + assert(expectedStruct === struct(Set("b", "d"))) + intercept[IllegalArgumentException] { + struct(Set("b", "d", "e", "f")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e17ecc87fd52a..037890682f7b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner} +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -class JoinSuite extends QueryTest { +class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. 
TestData @@ -36,6 +40,56 @@ class JoinSuite extends QueryTest { assert(planned.size === 1) } + test("join operator selection") { + def assertJoin(sqlString: String, c: Class[_]): Any = { + val rdd = sql(sqlString) + val physical = rdd.queryExecution.sparkPlan + val operators = physical.collect { + case j: ShuffledHashJoin => j + case j: HashOuterJoin => j + case j: LeftSemiJoinHash => j + case j: BroadcastHashJoin => j + case j: LeftSemiJoinBNL => j + case j: CartesianProduct => j + case j: BroadcastNestedLoopJoin => j + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + } + } + + val cases1 = Seq( + ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a where key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]) + // TODO add BroadcastNestedLoopJoin + ) + cases1.foreach { c => assertJoin(c._1, c._2) } + } + test("multiple-key equi-join is hash-join") { val x = testData2.as('x) val y = testData2.as('y) @@ -116,6 +170,33 @@ class JoinSuite extends QueryTest { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 
'L)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) } test("right outer join") { @@ -127,11 +208,38 @@ class JoinSuite extends QueryTest { (4, "d", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } test("full outer join") { - val left = upperCaseData.where('N <= 4).as('left) - val right = upperCaseData.where('N >= 3).as('right) + upperCaseData.where('N <= 4).registerAsTable("left") + upperCaseData.where('N >= 3).registerAsTable("right") + + val left = UnresolvedRelation(None, "left", None) + val right = UnresolvedRelation(None, "right", None) checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), @@ -141,5 +249,25 @@ class JoinSuite extends QueryTest { (4, "D", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8e1e1971d968b..1fd8d27b34c59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -45,6 +45,7 @@ class QueryTest extends PlanTest { |${rdd.queryExecution} |== Exception == |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala new file mode 100644 index 0000000000000..651cb735ab7d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -0,0 +1,46 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + +class RowSuite extends FunSuite { + + test("create row") { + val expected = new GenericMutableRow(4) + expected.update(0, 2147483647) + expected.update(1, "this is a string") + expected.update(2, false) + expected.update(3, null) + val actual1 = Row(2147483647, "this is a string", false, null) + assert(expected.size === actual1.size) + assert(expected.getInt(0) === actual1.getInt(0)) + assert(expected.getString(1) === actual1.getString(1)) + assert(expected.getBoolean(2) === actual1.getBoolean(2)) + assert(expected(3) === actual1(3)) + + val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) + assert(expected.size === actual2.size) + assert(expected.getInt(0) === actual2.getInt(0)) + assert(expected.getString(1) === actual2.getString(1)) + assert(expected.getBoolean(2) === actual2.getBoolean(2)) + assert(expected(3) === actual2(3)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 08293f7f0ca30..1a58d73d9e7f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -54,10 +54,10 @@ class SQLConfSuite extends QueryTest { assert(get(testKey, testVal + "_") == testVal) assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - sql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - sql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + sql("set some.property=20") + assert(get("some.property", "0") == "20") + sql("set some.property = 40") + assert(get("some.property", "0") == "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" @@ -70,4 +70,9 @@ class SQLConfSuite extends QueryTest { clear() } + test("deprecated property") { + clear() + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6736189c96d4b..5c571d35d1bb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test._ /* Implicits */ @@ -424,26 +422,107 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(testKey, testVal), - Seq(testKey + testKey, testVal + testVal)) + Seq(s"$testKey=$testVal"), + Seq(s"${testKey + 
testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(nonexistentKey, "")) + Seq(Seq(s"$nonexistentKey=")) ) clear() } + + test("apply schema") { + val schema1 = StructType( + StructField("f1", IntegerType, false) :: + StructField("f2", StringType, false) :: + StructField("f3", BooleanType, false) :: + StructField("f4", IntegerType, true) :: Nil) + + val rowRDD1 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(values(0).toInt, values(1), values(2).toBoolean, v4) + } + + val schemaRDD1 = applySchema(rowRDD1, schema1) + schemaRDD1.registerAsTable("applySchema1") + checkAnswer( + sql("SELECT * FROM applySchema1"), + (1, "A1", true, null) :: + (2, "B2", false, null) :: + (3, "C3", true, null) :: + (4, "D4", true, 2147483644) :: Nil) + + checkAnswer( + sql("SELECT f1, f4 FROM applySchema1"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + + val schema2 = StructType( + StructField("f1", StructType( + StructField("f11", IntegerType, false) :: + StructField("f12", BooleanType, false) :: Nil), false) :: + StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil) + + val rowRDD2 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) + } + + val schemaRDD2 = applySchema(rowRDD2, schema2) + schemaRDD2.registerAsTable("applySchema2") + checkAnswer( + sql("SELECT * FROM applySchema2"), + (Seq(1, true), Map("A1" -> null)) :: + (Seq(2, false), Map("B2" -> null)) :: + (Seq(3, true), Map("C3" -> null)) :: + (Seq(4, true), Map("D4" -> 2147483644)) :: Nil) + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + + // The value of a MapType column can be a mutable map. 
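+    // (rowRDD3 is identical to rowRDD2 except that it uses scala.collection.mutable.Map, so
+    // applySchema should accept any scala.collection.Map implementation for a MapType column.)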
+ val rowRDD3 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) + } + + val schemaRDD3 = applySchema(rowRDD3, schema2) + schemaRDD3.registerAsTable("applySchema3") + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 330b20b315d63..213190e812026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -128,4 +128,11 @@ object TestData { case class TableName(tableName: String) TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") + + val unparsedStrings = + TestSQLContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala new file mode 100644 index 0000000000000..ff1debff0f8c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.types.util.DataTypeConversions +import org.scalatest.FunSuite + +import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField} +import org.apache.spark.sql.{StructType => SStructType} +import DataTypeConversions._ + +class ScalaSideDataTypeConversionSuite extends FunSuite { + + def checkDataType(scalaDataType: SDataType) { + val javaDataType = asJavaDataType(scalaDataType) + val actual = asScalaDataType(javaDataType) + assert(scalaDataType === actual, s"Converted data type ${actual} " + + s"does not equal the expected data type ${scalaDataType}") + } + + test("convert data types") { + // Simple DataTypes. 
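+    // Each primitive type below should map to its Java counterpart and back unchanged.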
+ checkDataType(org.apache.spark.sql.StringType) + checkDataType(org.apache.spark.sql.BinaryType) + checkDataType(org.apache.spark.sql.BooleanType) + checkDataType(org.apache.spark.sql.TimestampType) + checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DoubleType) + checkDataType(org.apache.spark.sql.FloatType) + checkDataType(org.apache.spark.sql.ByteType) + checkDataType(org.apache.spark.sql.IntegerType) + checkDataType(org.apache.spark.sql.LongType) + checkDataType(org.apache.spark.sql.ShortType) + + // Simple ArrayType. + val simpleScalaArrayType = + org.apache.spark.sql.ArrayType(org.apache.spark.sql.StringType, true) + checkDataType(simpleScalaArrayType) + + // Simple MapType. + val simpleScalaMapType = + org.apache.spark.sql.MapType(org.apache.spark.sql.StringType, org.apache.spark.sql.LongType) + checkDataType(simpleScalaMapType) + + // Simple StructType. + val simpleScalaStructType = SStructType( + SStructField("a", org.apache.spark.sql.DecimalType, false) :: + SStructField("b", org.apache.spark.sql.BooleanType, true) :: + SStructField("c", org.apache.spark.sql.LongType, true) :: + SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) + checkDataType(simpleScalaStructType) + + // Complex StructType. + val complexScalaStructType = SStructType( + SStructField("simpleArray", simpleScalaArrayType, true) :: + SStructField("simpleMap", simpleScalaMapType, true) :: + SStructField("simpleStruct", simpleScalaStructType, true) :: + SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil) + checkDataType(complexScalaStructType) + + // Complex ArrayType. + val complexScalaArrayType = + org.apache.spark.sql.ArrayType(complexScalaStructType, true) + checkDataType(complexScalaArrayType) + + // Complex MapType. 
+ val complexScalaMapType = + org.apache.spark.sql.MapType(complexScalaStructType, complexScalaArrayType, false) + checkDataType(complexScalaMapType) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 35ab14cbc353d..3baa6f8ec0c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -41,7 +41,7 @@ object TestNullableColumnAccessor { class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index d8898527baa39..dc813fe146c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -37,7 +37,7 @@ object TestNullableColumnBuilder { class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnBuilder(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 215618e852eb2..76b1724471442 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite { test("count is partially aggregated") { val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed - val planned = PartialAggregation(query).head - val aggregations = planned.collect { case a: Aggregate => a } + val planned = HashAggregation(query).head + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } assert(aggregations.size === 2) } test("count distinct is not partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } test("mixed aggregates are not partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index e55648b8ed15a..2cab5e0c44d92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._ * Note: this is only a rough example of how TGFs can be expressed, the final version will likely * 
involve a lot more sugar for cleaner use in Scala/Java/etc. */ -case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator { +case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { def children = input protected def makeOutput() = 'nameAndAge.string :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index e765cfc83a397..9d9cfdd7c92e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ -protected case class Schema(output: Seq[Attribute]) extends LeafNode - class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -127,6 +123,18 @@ class JsonSuite extends QueryTest { checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType)) // StructType checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) @@ -164,16 +172,16 @@ class JsonSuite extends QueryTest { test("Primitive field and type inferring") { val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -192,27 +200,28 @@ class JsonSuite extends QueryTest { test("Complex field and type inferring") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - val expectedSchema = - AttributeReference("arrayOfArray1", 
ArrayType(ArrayType(StringType)), true)() :: - AttributeReference("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true)() :: - AttributeReference("arrayOfBigInteger", ArrayType(DecimalType), true)() :: - AttributeReference("arrayOfBoolean", ArrayType(BooleanType), true)() :: - AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() :: - AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() :: - AttributeReference("arrayOfLong", ArrayType(LongType), true)() :: - AttributeReference("arrayOfNull", ArrayType(StringType), true)() :: - AttributeReference("arrayOfString", ArrayType(StringType), true)() :: - AttributeReference("arrayOfStruct", ArrayType( - StructType(StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: Nil)), true)() :: - AttributeReference("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType, true) :: Nil), true)() :: - AttributeReference("structWithArrayFields", StructType( + val expectedSchema = StructType( + StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType), true) :: + StructField("arrayOfInteger", ArrayType(IntegerType), true) :: + StructField("arrayOfLong", ArrayType(LongType), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType), true) :: + StructField("arrayOfStruct", ArrayType( + StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: Nil)), true) :: + StructField("struct", StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType, true) :: Nil), true) :: + StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType), true) :: - StructField("field2", ArrayType(StringType), true) :: Nil), true)() :: Nil + StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -301,15 +310,15 @@ class JsonSuite extends QueryTest { test("Type conflict in primitive field values") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - val expectedSchema = - AttributeReference("num_bool", StringType, true)() :: - AttributeReference("num_num_1", LongType, true)() :: - AttributeReference("num_num_2", DecimalType, true)() :: - AttributeReference("num_num_3", DoubleType, true)() :: - AttributeReference("num_str", StringType, true)() :: - AttributeReference("str_bool", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("num_bool", StringType, true) :: + StructField("num_num_1", LongType, true) :: + StructField("num_num_2", DecimalType, true) :: + StructField("num_num_3", DoubleType, true) :: + StructField("num_str", StringType, true) :: + StructField("str_bool", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -426,15 +435,15 @@ class JsonSuite extends QueryTest { test("Type conflict in complex field values") { val jsonSchemaRDD = 
jsonRDD(complexFieldValueTypeConflict) - val expectedSchema = - AttributeReference("array", ArrayType(IntegerType), true)() :: - AttributeReference("num_struct", StringType, true)() :: - AttributeReference("str_array", StringType, true)() :: - AttributeReference("struct", StructType( - StructField("field", StringType, true) :: Nil), true)() :: - AttributeReference("struct_array", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("array", ArrayType(IntegerType), true) :: + StructField("num_struct", StringType, true) :: + StructField("str_array", StringType, true) :: + StructField("struct", StructType( + StructField("field", StringType, true) :: Nil), true) :: + StructField("struct_array", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -450,12 +459,12 @@ class JsonSuite extends QueryTest { test("Type conflict in array elements") { val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) - val expectedSchema = - AttributeReference("array1", ArrayType(StringType), true)() :: - AttributeReference("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil)), true)() :: Nil + val expectedSchema = StructType( + StructField("array1", ArrayType(StringType, true), true) :: + StructField("array2", ArrayType(StructType( + StructField("field", LongType, true) :: Nil)), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -475,15 +484,15 @@ class JsonSuite extends QueryTest { test("Handling missing fields") { val jsonSchemaRDD = jsonRDD(missingFields) - val expectedSchema = - AttributeReference("a", BooleanType, true)() :: - AttributeReference("b", LongType, true)() :: - AttributeReference("c", ArrayType(IntegerType), true)() :: - AttributeReference("d", StructType( - StructField("field", BooleanType, true) :: Nil), true)() :: - AttributeReference("e", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("a", BooleanType, true) :: + StructField("b", LongType, true) :: + StructField("c", ArrayType(IntegerType), true) :: + StructField("d", StructType( + StructField("field", BooleanType, true) :: Nil), true) :: + StructField("e", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") } @@ -494,16 +503,16 @@ class JsonSuite extends QueryTest { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val jsonSchemaRDD = jsonFile(path) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) 
- comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -518,4 +527,53 @@ class JsonSuite extends QueryTest { "this is a simple string.") :: Nil ) } + + test("Applying schemas") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val schema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + val jsonSchemaRDD1 = jsonFile(path, schema) + + assert(schema === jsonSchemaRDD1.schema) + + jsonSchemaRDD1.registerAsTable("jsonTable1") + + checkAnswer( + sql("select * from jsonTable1"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + + val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) + + assert(schema === jsonSchemaRDD2.schema) + + jsonSchemaRDD2.registerAsTable("jsonTable2") + + checkAnswer( + sql("select * from jsonTable2"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3c911e9a4e7b1..8955455ec98c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job + import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} @@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -207,18 +209,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection of simple Parquet file") { - val scanner = new ParquetTableScan( - ParquetTestData.testData.output, - ParquetTestData.testData, - Seq())(TestSQLContext) - val projected = scanner.pruneColumns(ParquetTypesConverter - .convertToAttributes(MessageTypeParser - .parseMessageType(ParquetTestData.subTestSchema))) - assert(projected.output.size === 2) - val result = projected - .execute() - .map(_.copy()) - .collect() + val result = ParquetTestData.testData.select('myboolean, 'mylong).collect() result.zipWithIndex.foreach { case (row, index) => { if (index % 3 == 0) diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml new file mode 100644 index 0000000000000..7fac90fdc596d --- /dev/null +++ b/sql/hive-thriftserver/pom.xml @@ -0,0 +1,82 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 
1.1.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-hive-thriftserver_2.10 + jar + Spark Project Hive + http://spark.apache.org/ + + hive-thriftserver + + + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + + + org.spark-project.hive + hive-cli + ${hive.version} + + + org.spark-project.hive + hive-jdbc + ${hive.version} + + + org.spark-project.hive + hive-beeline + ${hive.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala new file mode 100644 index 0000000000000..ddbc2a79fb512 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService +import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a + * `HiveThriftServer2` thrift server. + */ +private[hive] object HiveThriftServer2 extends Logging { + var LOG = LogFactory.getLog(classOf[HiveServer2]) + + def main(args: Array[String]) { + val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + + if (!optionsProcessor.process(args)) { + logger.warn("Error starting HiveThriftServer2 with given arguments") + System.exit(-1) + } + + val ss = new SessionState(new HiveConf(classOf[SessionState])) + + // Set all properties specified via command line. 
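    // [editor's note, not part of the patch] The HiveConf obtained below is assumed to already
    // reflect any --hiveconf arguments handled by ServerOptionsProcessor.process(args) above;
    // the loop that follows does not apply anything, it only logs the effective settings at
    // DEBUG level.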
+ val hiveConf: HiveConf = ss.getConf + hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => + logger.debug(s"HiveConf var: $k=$v") + } + + SessionState.start(ss) + + logger.info("Starting SparkContext") + SparkSQLEnv.init() + SessionState.start(ss) + + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.sparkContext.stop() + } + } + ) + + try { + val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) + server.init(hiveConf) + server.start() + logger.info("HiveThriftServer2 started") + } catch { + case e: Exception => + logger.error("Error starting HiveThriftServer2", e) + System.exit(-1) + } + } +} + +private[hive] class HiveThriftServer2(hiveContext: HiveContext) + extends HiveServer2 + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + setSuperField(this, "cliService", sparkSqlCliService) + addService(sparkSqlCliService) + + val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala new file mode 100644 index 0000000000000..599294dfbb7d7 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +private[hive] object ReflectionUtils { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + setAncestorField(obj, 1, fieldName, fieldValue) + } + + def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) { + val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.set(obj, fieldValue) + } + + def getSuperField[T](obj: AnyRef, fieldName: String): T = { + getAncestorField[T](obj, 1, fieldName) + } + + def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = { + val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(clazz).asInstanceOf[T] + } + + def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = { + invoke(clazz, null, methodName, args: _*) + } + + def invoke( + clazz: Class[_], + obj: AnyRef, + methodName: String, + args: (Class[_], AnyRef)*): AnyRef = { + + val (types, values) = args.unzip + val method = clazz.getDeclaredMethod(methodName, types: _*) + method.setAccessible(true) + method.invoke(obj, values.toSeq: _*) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala new file mode 100755 index 0000000000000..cb17d7ce58ea0 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io._ +import java.util.{ArrayList => JArrayList} + +import jline.{ConsoleReader, History} +import org.apache.commons.lang.StringUtils +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.thrift.transport.TSocket + +import org.apache.spark.sql.Logging + +private[hive] object SparkSQLCLIDriver { + private var prompt = "spark-sql" + private var continuedPrompt = "".padTo(prompt.length, ' ') + private var transport:TSocket = _ + + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. + */ + def installSignalHandler() { + HiveInterruptUtils.add(new HiveInterruptCallback { + override def interrupt() { + // Handle remote execution mode + if (SparkSQLEnv.sparkContext != null) { + SparkSQLEnv.sparkContext.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket.close() + } + } + } + }) + } + + def main(args: Array[String]) { + val oproc = new OptionsProcessor() + if (!oproc.process_stage1(args)) { + System.exit(1) + } + + // NOTE: It is critical to do this here so that log4j is reinitialized + // before any of the other core hive classes are loaded + var logInitFailed = false + var logInitDetailMessage: String = null + try { + logInitDetailMessage = LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => + logInitFailed = true + logInitDetailMessage = e.getMessage + } + + val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) + + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (!oproc.process_stage2(sessionState)) { + System.exit(2) + } + + if (!sessionState.getIsSilent) { + if (logInitFailed) System.err.println(logInitDetailMessage) + else SessionState.getConsole.printInfo(logInitDetailMessage) + } + + // Set all properties specified via command line. 
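    // [editor's note, not part of the patch] For example, invoking the CLI as
    //   spark-sql --hiveconf hive.metastore.warehouse.dir=/tmp/warehouse
    // (the kind of flag CliSuite passes further below) is expected to surface here as an entry in
    // sessionState.cmdProperties, which this loop copies into the HiveConf and records as an
    // overridden configuration.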
+ val conf: HiveConf = sessionState.getConf + sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => + conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + sessionState.getOverriddenConfigurations.put( + item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + } + + SessionState.start(sessionState) + + // Clean up after we exit + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.stop() + } + } + ) + + // "-h" option has been passed, so connect to Hive thrift server. + if (sessionState.getHost != null) { + sessionState.connect() + if (sessionState.isRemoteMode) { + prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt + continuedPrompt = "".padTo(prompt.length, ' ') + } + } + + if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + // Hadoop-20 and above - we need to augment classpath using hiveconf + // components. + // See also: code in ExecDriver.java + var loader = conf.getClassLoader + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ",")) + } + conf.setClassLoader(loader) + Thread.currentThread().setContextClassLoader(loader) + } + + val cli = new SparkSQLCLIDriver + cli.setHiveVariables(oproc.getHiveVariables) + + // TODO work around for set the log output to console, because the HiveContext + // will set the output into an invalid buffer. + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + // Execute -i init files (always in silent mode) + cli.processInitFiles(sessionState) + + if (sessionState.execString != null) { + System.exit(cli.processLine(sessionState.execString)) + } + + try { + if (sessionState.fileName != null) { + System.exit(cli.processFile(sessionState.fileName)) + } + } catch { + case e: FileNotFoundException => + System.err.println(s"Could not open input file for reading. (${e.getMessage})") + System.exit(3) + } + + val reader = new ConsoleReader() + reader.setBellEnabled(false) + // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) + CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + + val historyDirectory = System.getProperty("user.home") + + try { + if (new File(historyDirectory).exists()) { + val historyFile = historyDirectory + File.separator + ".hivehistory" + reader.setHistory(new History(new File(historyFile))) + } else { + System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + " does not exist. History will not be available during this session.") + } + } catch { + case e: Exception => + System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + "history file. 
History will not be available during this session.") + System.err.println(e.getMessage) + } + + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] + + var ret = 0 + var prefix = "" + val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", + classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) + + def promptWithCurrentDB = s"$prompt$currentDB" + def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + + var currentPrompt = promptWithCurrentDB + var line = reader.readLine(currentPrompt + "> ") + + while (line != null) { + if (prefix.nonEmpty) { + prefix += '\n' + } + + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } + + line = reader.readLine(currentPrompt + "> ") + } + + sessionState.close() + + System.exit(ret) + } +} + +private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { + private val sessionState = SessionState.get().asInstanceOf[CliSessionState] + + private val LOG = LogFactory.getLog("CliDriver") + + private val console = new SessionState.LogHelper(LOG) + + private val conf: Configuration = + if (sessionState != null) sessionState.getConf else new Configuration() + + // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver + // because the Hive unit tests do not go through the main() code path. + if (!sessionState.isRemoteMode) { + SparkSQLEnv.init() + } + + override def processCmd(cmd: String): Int = { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + if (cmd_trimmed.toLowerCase.equals("quit") || + cmd_trimmed.toLowerCase.equals("exit") || + tokens(0).equalsIgnoreCase("source") || + cmd_trimmed.startsWith("!") || + tokens(0).toLowerCase.equals("list") || + sessionState.isRemoteMode) { + val start = System.currentTimeMillis() + super.processCmd(cmd) + val end = System.currentTimeMillis() + val timeTaken: Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds") + 0 + } else { + var ret = 0 + val hconf = conf.asInstanceOf[HiveConf] + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + + if (proc != null) { + if (proc.isInstanceOf[Driver]) { + val driver = new SparkSQLDriver + + driver.init() + val out = sessionState.out + val start:Long = System.currentTimeMillis() + if (sessionState.getIsVerbose) { + out.println(cmd) + } + + val rc = driver.run(cmd) + ret = rc.getResponseCode + if (ret != 0) { + console.printError(rc.getErrorMessage()) + driver.close() + return ret + } + + val res = new JArrayList[String]() + + if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { + // Print the column names. 
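            // [editor's note, not part of the patch] The header is only emitted when
            // hive.cli.print.header=true, and the code below joins the field names with "\t",
            // i.e. a tab-separated header row like the regular Hive CLI prints.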
+ Option(driver.getSchema.getFieldSchemas).map { fields => + out.println(fields.map(_.getName).mkString("\t")) + } + } + + try { + while (!out.checkError() && driver.getResults(res)) { + res.foreach(out.println) + res.clear() + } + } catch { + case e:IOException => + console.printError( + s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + ret = 1 + } + + val cret = driver.close() + if (ret == 0) { + ret = cret + } + + val end = System.currentTimeMillis() + if (end > start) { + val timeTaken:Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds", null) + } + + // Destroy the driver to release all the locks. + driver.destroy() + } else { + if (sessionState.getIsVerbose) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } + ret = proc.run(cmd_1).getResponseCode + } + } + ret + } + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala new file mode 100644 index 0000000000000..42cbf363b274f --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io.IOException +import java.util.{List => JList} +import javax.security.auth.login.LoginException + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hive.service.Service.STATE +import org.apache.hive.service.auth.HiveAuthFactory +import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.{AbstractService, Service, ServiceException} + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +private[hive] class SparkSQLCLIService(hiveContext: HiveContext) + extends CLIService + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + setSuperField(this, "sessionManager", sparkSqlSessionManager) + addService(sparkSqlSessionManager) + + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + val serverUserName = ShimLoader.getHadoopShims + .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) + setSuperField(this, "serverUserName", serverUserName) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } + + initCompositeService(hiveConf) + } +} + +private[thriftserver] trait ReflectedCompositeService { this: AbstractService => + def initCompositeService(hiveConf: HiveConf) { + // Emulating `CompositeService.init(hiveConf)` + val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") + serviceList.foreach(_.init(hiveConf)) + + // Emulating `AbstractService.init(hiveConf)` + invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) + setAncestorField(this, 3, "hiveConf", hiveConf) + invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED) + getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000000..a56b19a4bcda0 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.util.{ArrayList => JArrayList} + +import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver with Logging { + + private var tableSchema: Schema = _ + private var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logger.debug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.size == 0) { + new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + // TODO unify the error code + try { + val execution = context.executePlan(context.hql(command).logicalPlan) + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logger.error(s"Failed in [$command]", cause) + new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getSchema: Schema = tableSchema + + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala new file mode 100644 index 0000000000000..451c3bd7b9352 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.session.SessionState + +import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{SparkConf, SparkContext} + +/** A singleton object for the master program. The slaves should not access this. */ +private[hive] object SparkSQLEnv extends Logging { + logger.debug("Initializing SparkSQLEnv") + + var hiveContext: HiveContext = _ + var sparkContext: SparkContext = _ + + def init() { + if (hiveContext == null) { + sparkContext = new SparkContext(new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + + sparkContext.addSparkListener(new StatsReportListener()) + + hiveContext = new HiveContext(sparkContext) { + @transient override lazy val sessionState = SessionState.get() + @transient override lazy val hiveconf = sessionState.getConf + } + } + } + + /** Cleans up and shuts down the Spark SQL environments. */ + def stop() { + logger.debug("Shutting down Spark SQL Environment") + // Stop the SparkContext + if (SparkSQLEnv.sparkContext != null) { + sparkContext.stop() + sparkContext = null + hiveContext = null + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..6b3275b4eaf04 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.session.SessionManager + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala new file mode 100644 index 0000000000000..a4e1f3e762e89 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.server + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.math.{random, round} + +import java.sql.Timestamp +import java.util.{Map => JMap} + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} + +/** + * Executes queries using Spark SQL, and maintains a list of handles to active queries. 
+ */ +class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { + val handleToOperation = ReflectionUtils + .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + + override def newExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + async: Boolean): ExecuteStatementOperation = synchronized { + + val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logger.debug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No. + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + dataTypes(curCol) match { + case StringType => + row.addString(sparkRow(curCol).asInstanceOf[String]) + case IntegerType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) + case BooleanType => + row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) + case DoubleType => + row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) + case FloatType => + row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) + case DecimalType => + val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal + row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) + case ByteType => + row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) + case ShortType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) + case TimestampType => + row.addColumnValue( + ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) + row.addColumnValue(ColumnValue.stringValue(hiveString)) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def getResultSetSchema: TableSchema = { + logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logger.info(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + result = hiveContext.hql(statement) + logger.debug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = result.queryExecution.toRdd.toLocalIterator + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some 
failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logger.error("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } + } + + handleToOperation.put(operation.getHandle, operation) + operation + } +} diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt new file mode 100644 index 0000000000000..850f8014b6f05 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt @@ -0,0 +1,5 @@ +238val_238 +86val_86 +311val_311 +27val_27 +165val_165 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala new file mode 100644 index 0000000000000..69f19f826a802 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, InputStreamReader, PrintWriter} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { + val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli") + val METASTORE_PATH = TestUtils.getMetastorePath("cli") + + override def beforeAll() { + val pb = new ProcessBuilder( + "../../bin/spark-sql", + "--master", + "local", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) + + process = pb.start() + outputWriter = new PrintWriter(process.getOutputStream, true) + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "spark-sql>") + } + + override def afterAll() { + process.destroy() + process.waitFor() + } + + test("simple commands") { + val dataFilePath = getDataFile("data/files/small_kv.txt") + executeQuery("create table hive_test1(key int, val string);") + executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;") + executeQuery("cache table hive_test1", "Time taken") + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala new file mode 100644 index 0000000000000..fe3403b3292ec --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ + +import java.io.{BufferedReader, InputStreamReader} +import java.net.ServerSocket +import java.sql.{Connection, DriverManager, Statement} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.util.getTempFilePath + +/** + * Test for the HiveThriftServer2 using JDBC. + */ +class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging { + + val WAREHOUSE_PATH = getTempFilePath("warehouse") + val METASTORE_PATH = getTempFilePath("metastore") + + val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver" + val TABLE = "test" + val HOST = "localhost" + val PORT = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. 
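    // [editor's note, not part of the patch] Binding a ServerSocket on port 0 asks the OS for a
    // free ephemeral port. The socket is closed again immediately so that the forked Thrift
    // server can bind the same port; this leaves a small window in which another process could
    // grab it, which the test accepts as a trade-off.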
+ val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } + + // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. + val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean + + Class.forName(DRIVER_NAME) + + override def beforeAll() { launchServer() } + + override def afterAll() { stopServer() } + + private def launchServer(args: Seq[String] = Seq.empty) { + // Forking a new process to start the Hive Thrift server. The reason to do this is it is + // hard to clean up Hive resources entirely, so we just start a new process and kill + // that process for cleanup. + val defaultArgs = Seq( + "../../sbin/start-thriftserver.sh", + "--master local", + "--hiveconf", + "hive.root.logger=INFO,console", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") + val pb = new ProcessBuilder(defaultArgs ++ args) + val environment = pb.environment() + environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) + environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) + process = pb.start() + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "ThriftBinaryCLIService listening on") + + // Spawn a thread to read the output from the forked process. + // Note that this is necessary since in some configurations, log4j could be blocked + // if its output to stderr are not read, and eventually blocking the entire test suite. + future { + while (true) { + val stdout = readFrom(inputReader) + val stderr = readFrom(errorReader) + if (VERBOSE && stdout.length > 0) { + println(stdout) + } + if (VERBOSE && stderr.length > 0) { + println(stderr) + } + Thread.sleep(50) + } + } + } + + private def stopServer() { + process.destroy() + process.waitFor() + } + + test("test query execution against a Hive Thrift server") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test") + stmt.execute("DROP TABLE IF EXISTS test_cached") + stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") + stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CACHE TABLE test_cached") + + var rs = stmt.executeQuery("select count(*) from test") + rs.next() + assert(rs.getInt(1) === 5) + + rs = stmt.executeQuery("select count(*) from test_cached") + rs.next() + assert(rs.getInt(1) === 4) + + stmt.close() + } + + def getConnection: Connection = { + val connectURI = s"jdbc:hive2://localhost:$PORT/" + DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") + } + + def createStatement(): Statement = getConnection.createStatement() +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala new file mode 100644 index 0000000000000..bb2242618fbef --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, PrintWriter} +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException + +object TestUtils { + val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss") + + def getWarehousePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" + + timestamp.format(new Date) + } + + def getMetastorePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" + + timestamp.format(new Date) + } + + // Dummy function for initialize the log4j properties. + def init() { } + + // initialize log4j + try { + LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => // Ignore the error. + } +} + +trait TestUtils { + var process : Process = null + var outputWriter : PrintWriter = null + var inputReader : BufferedReader = null + var errorReader : BufferedReader = null + + def executeQuery( + cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = { + println("Executing: " + cmd + ", expecting output: " + outputMessage) + outputWriter.write(cmd + "\n") + outputWriter.flush() + waitForQuery(timeout, outputMessage) + } + + protected def waitForQuery(timeout: Long, message: String): String = { + if (waitForOutput(errorReader, message, timeout)) { + Thread.sleep(500) + readOutput() + } else { + assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput()) + null + } + } + + // Wait for the specified str to appear in the output. + protected def waitForOutput( + reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = { + val startTime = System.currentTimeMillis + var out = "" + while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) { + out += readFrom(reader) + } + out.contains(str) + } + + // Read stdout output and filter out garbage collection messages. 
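  // [editor's note, not part of the patch] i.e. JVM GC log lines such as
  // "[GC 8192K->1024K(65536K), 0.0012 secs]" or "[Full GC ...]" are filtered out below (the
  // check is a simple substring match on "[GC" and "[Full GC"), so waitForQuery sees only the
  // actual command output.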
+ protected def readOutput(): String = { + val output = readFrom(inputReader) + // Remove GC Messages + val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC")) + .mkString("\n") + filteredOutput + } + + protected def readFrom(reader: BufferedReader): String = { + var out = "" + var c = 0 + while (reader.ready) { + c = reader.read() + out += c.asInstanceOf[Char] + } + out + } + + protected def getDataFile(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(name) + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c69e93ba2b9ba..4fef071161719 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -52,7 +52,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def blackList = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", - "hook_context", + "hook_context_cs", "mapjoin_hook", "multi_sahooks", "overridden_confs", @@ -289,7 +289,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "compute_stats_empty_table", "compute_stats_long", "compute_stats_string", - "compute_stats_table", "convert_enum_to_string", "correlationoptimizer1", "correlationoptimizer10", @@ -395,7 +394,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_sort_9", "groupby_sort_test_1", "having", - "having1", "implicit_cast1", "innerjoin", "inoutdriver", @@ -697,8 +695,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf7", "udf8", "udf9", - "udf_E", - "udf_PI", "udf_abs", "udf_acos", "udf_add", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 1699ffe06ce15..93d00f7c37c9b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -32,7 +32,7 @@ Spark Project Hive http://spark.apache.org/ - hive + hive diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 201c85f3d501e..7e3b8727bebed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -37,15 +37,17 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand /** - * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is - * created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. + * DEPRECATED: Use HiveContext instead. */ +@deprecated(""" + Use HiveContext instead. It will still create a local metastore if one is not specified. 
+ However, note that the default directory is ./metastore_db, not ./metastore + """, "1.1") class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val metastorePath = new File("metastore").getCanonicalPath @@ -129,12 +131,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + + ss.err = new PrintStream(outputBuffer, true, "UTF-8") + ss.out = new PrintStream(outputBuffer, true, "UTF-8") + ss } - sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") - sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") - override def set(key: String, value: String): Unit = { super.set(key, value) runSqlHive(s"SET $key=$value") @@ -231,7 +234,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveTableScans, DataSinks, Scripts, - PartialAggregation, + HashAggregation, LeftSemiJoin, HashJoin, BasicOperators, @@ -255,14 +258,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DecimalType, TimestampType, BinaryType) - protected def toHiveString(a: (Any, DataType)): String = a match { + protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -279,9 +282,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index ad7dc0ecdb1bf..354fcd53f303b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -152,8 +152,9 @@ private[hive] trait HiveInspectors { } def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => + case ArrayType(tpe, _) => + ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toInspector(keyType), toInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 156b090712df2..fa4e78439c26c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,15 +19,16 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.Logging +import org.apache.spark.sql.{SQLContext, Logging} import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical @@ -64,9 +65,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Since HiveQL is case insensitive for table names we make them all lowercase. MetastoreRelation( - databaseName, - tblName, - alias)(table.getTTable, partitions.map(part => part.getTPartition)) + databaseName, tblName, alias)( + table.getTTable, partitions.map(part => part.getTPartition))(hive) } def createTable( @@ -200,7 +200,9 @@ object HiveMetastoreTypes extends RegexParsers { "varchar\\((\\d+)\\)".r ^^^ StringType protected lazy val arrayType: Parser[DataType] = - "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType + "array" ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } protected lazy val mapType: Parser[DataType] = "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { @@ -229,10 +231,10 @@ object HiveMetastoreTypes extends RegexParsers { } def toMetastoreType(dt: DataType): String = dt match { - case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>" + case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" - case MapType(keyType, valueType) => + case MapType(keyType, valueType, _) => s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" case StringType => "string" case FloatType => "float" @@ -251,7 +253,11 @@ object HiveMetastoreTypes extends RegexParsers { private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) - extends BaseRelation { + (@transient sqlContext: SQLContext) + extends LeafNode { + + self: Product => + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. // Right now, using org.apache.hadoop.hive.ql.metadata.Table and @@ -264,6 +270,21 @@ private[hive] case class MetastoreRelation new Partition(hiveQlTable, p) } + @transient override lazy val statistics = Statistics( + sizeInBytes = { + // TODO: check if this estimate is valid for tables after partition pruning. 
+ // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. An + // alternative would be going through Hadoop's FileSystem API, which can be expensive if a lot + // of RPCs are involved. Besides `totalSize`, there are also `numFiles`, `numRows`, + // `rawDataSize` keys that we can look at in the future. + BigInt( + Option(hiveQlTable.getParameters.get("totalSize")) + .map(_.toLong) + .getOrElse(sqlContext.defaultSizeInBytes)) + } + ) + val tableDesc = new TableDesc( Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, @@ -275,14 +296,14 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { - def toAttribute = AttributeReference( - f.getName, - HiveMetastoreTypes.toDataType(f.getType), - // Since data can be dumped in randomly with no validation, everything is nullable. - nullable = true - )(qualifiers = tableName +: alias.toSeq) - } + implicit class SchemaAttribute(f: FieldSchema) { + def toAttribute = AttributeReference( + f.getName, + HiveMetastoreTypes.toDataType(f.getType), + // Since data can be dumped in randomly with no validation, everything is nullable. + nullable = true + )(qualifiers = tableName +: alias.toSeq) + } // Must be a stable value since new attributes are born here. val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e6ab68b563f8d..3d2eb1eefaeda 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -44,6 +44,8 @@ private[hive] case class SourceCommand(filePath: String) extends Command private[hive] case class AddFile(filePath: String) extends Command +private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( @@ -96,7 +98,6 @@ private[hive] object HiveQl { "TOK_CREATEINDEX", "TOK_DROPDATABASE", "TOK_DROPINDEX", - "TOK_DROPTABLE", "TOK_MSCK", // TODO(marmbrus): Figure out how view are expanded by hive, as we might need to handle this. @@ -377,6 +378,12 @@ private[hive] object HiveQl { } protected def nodeToPlan(node: Node): LogicalPlan = node match { + // Special drop table that also uncaches. + case Token("TOK_DROPTABLE", + Token("TOK_TABNAME", tableNameParts) :: + ifExists) => + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + DropTable(tableName, ifExists.nonEmpty) // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => @@ -610,7 +617,7 @@ private[hive] object HiveQl { // TOK_DESTINATION means to overwrite the table. 
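The statistics estimate added to MetastoreRelation above reads `totalSize` from the metastore parameters; its NOTE names Hadoop's FileSystem API as the costlier alternative. A minimal sketch of that alternative follows, as an illustration only and not part of the patch; `tablePath` and `hadoopConf` are hypothetical names assumed to be in scope.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Sketch: sum the bytes under the table's data directory via the FileSystem API.
// This may issue many RPCs for tables with many files, which is why the patch
// prefers the metastore's pre-populated `totalSize` parameter with a default fallback.
def estimateSizeViaFileSystem(tablePath: Path, hadoopConf: Configuration): BigInt = {
  val fs = tablePath.getFileSystem(hadoopConf)
  BigInt(fs.getContentSummary(tablePath).getLength)
}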
val resultDestination = (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = if (intoClause.isEmpty) true else false + val overwrite = intoClause.isEmpty nodeToDest( resultDestination, withLimit, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 4d0fab4140b21..2175c5f3835a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -81,6 +81,8 @@ private[hive] trait HiveStrategies { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c3942578d6b5a..82c88280d7754 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -24,6 +24,8 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -31,13 +33,16 @@ import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} +import org.apache.spark.sql.catalyst.types.DataType + /** * A trait for subclasses that handle table scans. */ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[_] + def makeRDDForTable(hiveTable: HiveTable): RDD[Row] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] } @@ -46,7 +51,10 @@ private[hive] sealed trait TableReader { * data warehouse directory. */ private[hive] -class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext) +class HadoopTableReader( + @transient attributes: Seq[Attribute], + @transient relation: MetastoreRelation, + @transient sc: HiveContext) extends TableReader { // Choose the minimum number of splits. 
If mapred.map.tasks is set, then use that unless @@ -63,10 +71,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def hiveConf = _broadcastedHiveConf.value.value - override def makeRDDForTable(hiveTable: HiveTable): RDD[_] = + override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -81,14 +89,14 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[_] = { + filterOpt: Option[PathFilter]): RDD[Row] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. - val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val tablePath = hiveTable.getPath @@ -99,23 +107,20 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + val attrsWithIndex = attributes.zipWithIndex + val mutableRow = new GenericMutableRow(attrsWithIndex.length) val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - // Deserialize each Writable to get the row value. - iter.map { - case v: Writable => deserializer.deserialize(v) - case value => - sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}") - } + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = { + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) @@ -132,9 +137,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon * subdirectory of each partition being read. If None, then all files are accepted. */ def makeRDDForPartitionedTable( - partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[_] = { - + partitionToDeserializer: Map[HivePartition, + Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[Row] = { val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = partition.getPartitionPath @@ -156,33 +161,42 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon } // Create local references so that the outer object isn't serialized. 
- val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer + val mutableRow = new GenericMutableRow(attributes.length) + + // split the attributes (output schema) into 2 categories: + // (partition keys, ordinal), (normal attributes, ordinal), the ordinal means the + // index of the attribute in the output Row. + val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { + relation.partitionKeys.indexOf(attr._1) >= 0 + }) + + def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { + partitionKeys.foreach { case (attr, ordinal) => + // get the partition key ordinal for a given attribute + val partOrdinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(parts(partOrdinal)), attr.dataType).eval(null) + } + } + // fill the partition keys for the given MutableRow object + fillPartitionKeys(partValues, mutableRow) val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) hivePartitionRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value - val rowWithPartArr = new Array[Object](2) - - // The update and deserializer initialization are intentionally - // kept out of the below iter.map loop to save performance. - rowWithPartArr.update(1, partValues) val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // Map each tuple to a row object - iter.map { value => - val deserializedRow = deserializer.deserialize(value) - rowWithPartArr.update(0, deserializedRow) - rowWithPartArr.asInstanceOf[Object] - } + // fill the non-partition-key attributes + HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) } }.toSeq // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[Object](sc.sparkContext) + new EmptyRDD[Row](sc.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -225,10 +239,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon // Only take the value (skip the key) because Hive works only with values. rdd.map(_._2) } - } -private[hive] object HadoopTableReader { +private[hive] object HadoopTableReader extends HiveInspectors { /** * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to * instantiate a HadoopRDD. @@ -241,4 +254,40 @@ private[hive] object HadoopTableReader { val bufferSize = System.getProperty("spark.buffer.size", "65536") jobConf.set("io.file.buffer.size", bufferSize) } + + /** + * Transforms the raw data (Writable objects) into Row objects for an iterable input + * @param iter Iterable input represented as Writable objects + * @param deserializer Deserializer associated with the input writable object + * @param attrs The row attribute names and their zero-based positions in the MutableRow + * @param row Reusable MutableRow object + * + * @return An iterator of Row objects transformed from the given iterable input.
+ */ + def fillObject( + iter: Iterator[Writable], + deserializer: Deserializer, + attrs: Seq[(Attribute, Int)], + row: GenericMutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] + // get the field references according to the attributes(output of the reader) required + val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + + // Map each tuple to a row object + iter.map { value => + val raw = deserializer.deserialize(value) + var idx = 0; + while (idx < fieldRefs.length) { + val fieldRef = fieldRefs(idx)._1 + val fieldIdx = fieldRefs(idx)._2 + val fieldValue = soi.getStructFieldData(raw, fieldRef) + + row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) + + idx += 1 + } + + row: Row + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 9386008d02d51..c50e8c4b5c5d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -53,15 +53,24 @@ object TestHive * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. */ -class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { +class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. System.clearProperty("spark.hostPort") - override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath - override lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath + lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + + /** Sets up the system initially or after a RESET command */ + protected def configure() { + set("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metastorePath;create=true") + set("hive.metastore.warehouse.dir", warehousePath) + } + + configure() // Must be called before initializing the catalog below. /** The location of the compiled hive distribution */ lazy val hiveHome = envVarToFile("HIVE_HOME") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala new file mode 100644 index 0000000000000..9cd0c86c6c796 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.execution.{Command, LeafNode} +import org.apache.spark.sql.hive.HiveContext + +/** + * :: DeveloperApi :: + * Drops a table from the metastore and removes it if it is cached. + */ +@DeveloperApi +case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with Command { + + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + def output = Seq.empty + + override protected[sql] lazy val sideEffectResult: Seq[Any] = { + val ifExistsClause = if (ifExists) "IF EXISTS " else "" + hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") + hiveContext.catalog.unregisterTable(None, tableName) + Seq.empty + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index e7016fa16eea9..8920e2a76a27f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.util.MutablePair /** * :: DeveloperApi :: @@ -50,8 +49,7 @@ case class HiveTableScan( relation: MetastoreRelation, partitionPruningPred: Option[Expression])( @transient val context: HiveContext) - extends LeafNode - with HiveInspectors { + extends LeafNode { require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") @@ -67,42 +65,7 @@ case class HiveTableScan( } @transient - private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context) - - /** - * The hive object inspector for this table, which can be used to extract values from the - * serialized row representation. - */ - @transient - private[this] lazy val objectInspector = - relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] - - /** - * Functions that extract the requested attributes from the hive output. Partitioned values are - * casted from string to its declared data type. 
- */ - @transient - protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { - attributes.map { a => - val ordinal = relation.partitionKeys.indexOf(a) - if (ordinal >= 0) { - val dataType = relation.partitionKeys(ordinal).dataType - (_: Any, partitionKeys: Array[String]) => { - castFromString(partitionKeys(ordinal), dataType) - } - } else { - val ref = objectInspector.getAllStructFieldRefs - .find(_.getFieldName == a.name) - .getOrElse(sys.error(s"Can't find attribute $a")) - val fieldObjectInspector = ref.getFieldObjectInspector - - (row: Any, _: Array[String]) => { - val data = objectInspector.getStructFieldData(row, ref) - unwrapData(data, fieldObjectInspector) - } - } - } - } + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context) private[this] def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) @@ -114,6 +77,7 @@ case class HiveTableScan( val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") if (attributes.size == relation.output.size) { + // SQLContext#pruneFilterProject guarantees no duplicated value in `attributes` ColumnProjectionUtils.setFullyReadColumns(hiveConf) } else { ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) @@ -140,12 +104,6 @@ case class HiveTableScan( addColumnMetadataToConf(context.hiveconf) - private def inputRdd = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) - } - /** * Prunes partitions not involve the query plan. * @@ -169,44 +127,10 @@ case class HiveTableScan( } } - override def execute() = { - inputRdd.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val mutableRow = new GenericMutableRow(attributes.length) - val mutablePair = new MutablePair[Any, Array[String]]() - val buffered = iterator.buffered - - // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern - // matching are avoided intentionally. 
- val rowsAndPartitionKeys = buffered.head match { - // With partition keys - case _: Array[Any] => - buffered.map { case array: Array[Any] => - val deserializedRow = array(0) - val partitionKeys = array(1).asInstanceOf[Array[String]] - mutablePair.update(deserializedRow, partitionKeys) - } - - // Without partition keys - case _ => - val emptyPartitionKeys = Array.empty[String] - buffered.map { deserializedRow => - mutablePair.update(deserializedRow, emptyPartitionKeys) - } - } - - rowsAndPartitionKeys.map { pair => - var i = 0 - while (i < attributes.length) { - mutableRow(i) = attributeFunctions(i)(pair._1, pair._2) - i += 1 - } - mutableRow: Row - } - } - } + override def execute() = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } override def output = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c2b0b00aa5852..39033bdeac4b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -131,7 +131,7 @@ case class InsertIntoHiveTable( conf, SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) - logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) writer.preSetup() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8258ee5fef0eb..0c8f676e9c5c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -67,7 +67,7 @@ case class ScriptTransformation( } } readerThread.start() - val outputProjection = new Projection(input) + val outputProjection = new InterpretedProjection(input, child.output) iter .map(outputProjection) // TODO: Use SerDe diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 057eb60a02612..7582b4743d404 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val function: GenericUDTF = createFunction() + @transient protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + @transient protected lazy val outputInspectors = { val structInspector = function.initialize(inputInspectors.toArray) structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) @@ -278,7 +280,7 @@ private[hive] case class HiveGenericUdtf( override def eval(input: Row): TraversableOnce[Row] = { outputInspectors // Make sure initialized. 
- val inputProjection = new Projection(children) + val inputProjection = new InterpretedProjection(children) val collector = new UDTFCollector function.setCollector(collector) @@ -332,7 +334,7 @@ private[hive] case class HiveUdafFunction( override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) @transient - val inputProjection = new Projection(exprs) + val inputProjection = new InterpretedProjection(exprs) def update(input: Row): Unit = { val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d new file mode 100644 index 0000000000000..00750edc07d64 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 new file mode 100644 index 0000000000000..00750edc07d64 --- /dev/null +++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 b/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 new file mode 100644 index 0000000000000..3f2cab688ccc2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 @@ -0,0 +1,136 @@ +0 +5 +12 +15 +18 +24 +26 +35 +37 +42 +51 +58 +67 +70 +72 +76 +83 +84 +90 +95 +97 +98 +100 +103 +104 +113 +118 +119 +120 +125 +128 +129 +134 +137 +138 +146 +149 +152 +164 +165 +167 +169 +172 +174 +175 +176 +179 +187 +191 +193 +195 +197 +199 +200 +203 +205 +207 +208 +209 +213 +216 +217 +219 +221 +223 +224 +229 +230 +233 +237 +238 +239 +242 +255 +256 +265 +272 +273 +277 +278 +280 +281 +282 +288 +298 +307 +309 +311 +316 +317 +318 +321 +322 +325 +327 +331 +333 +342 +344 +348 +353 +367 +369 +382 +384 +395 +396 +397 +399 +401 +403 +404 +406 +409 +413 +414 +417 +424 +429 +430 +431 +438 +439 +454 +458 +459 +462 +463 +466 +468 +469 +478 +480 +489 +492 +498 diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 new file mode 100644 index 0000000000000..f369f21e1833f --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 @@ -0,0 +1,2 @@ +100 100 2010-01-01 +200 200 2010-01-02 diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 3132d0112c708..08da6405a17c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive class CachedTableSuite extends HiveComparisonTest { + import TestHive._ + TestHive.loadTestTable("src") test("cache table") { @@ -32,6 +34,20 @@ class CachedTableSuite extends HiveComparisonTest { createQueryTest("read from cached table", "SELECT * FROM src LIMIT 1", reset = false) + test("Drop cached table") { + hql("CREATE TABLE test(a INT)") + cacheTable("test") + hql("SELECT * FROM test").collect() + hql("DROP TABLE test") + intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { + hql("SELECT * FROM test").collect() + } + } + + test("DROP nonexistant table") { + hql("DROP TABLE IF EXISTS nonexistantTable") + } + test("check that table is cached and uncache") { TestHive.table("src").queryExecution.analyzed match { case _ : InMemoryRelation => // Found evidence of caching diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala new file mode 100644 index 0000000000000..a61fd9df95c94 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +class StatisticsSuite extends QueryTest { + + test("estimates the size of a test MetastoreRelation") { + val rdd = hql("""SELECT * FROM src""") + val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => + mr.statistics.sizeInBytes + } + assert(sizes.size === 1) + assert(sizes(0).equals(BigInt(5812)), + s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") + } + + test("auto converts to broadcast hash join, by size estimate of a relation") { + def mkTest( + before: () => Unit, + after: () => Unit, + query: String, + expectedAnswer: Seq[Any], + ct: ClassTag[_]) = { + before() + + var rdd = hql(query) + + // Assert src has a size smaller than the threshold. + val sizes = rdd.queryExecution.analyzed.collect { + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes + } + assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold, + s"query should contain two relations, each of which has size smaller than autoConvertSize") + + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. + var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + assert(bhj.size === 1, + s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + + checkAnswer(rdd, expectedAnswer) // check correctness of output + + TestHive.settings.synchronized { + val tmp = autoBroadcastJoinThreshold + + hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + rdd = hql(query) + bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") + + val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + assert(shj.size === 1, + "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") + + hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + } + + after() + } + + /** Tests for MetastoreRelation */ + val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" + val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238")) + mkTest( + () => (), + () => (), + metastoreQuery, + metastoreAnswer, + implicitly[ClassTag[MetastoreRelation]] + ) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index b4dbf2b115799..6c8fe4b196dea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -132,7 +132,7 @@ abstract class HiveComparisonTest answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { - case _: Join | _: Aggregate | _: BaseRelation | _: Generate | _: Sample | _: Distinct => false + case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false case PhysicalOperation(_, _, Sort(_, _)) => true case _ => plan.children.iterator.exists(isSorted) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a8623b64c656f..89cc589fb8001 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{Row, SchemaRDD} case class TestData(a: Int, b: String) @@ -30,6 +32,18 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("single case", + """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") + + createQueryTest("double case", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""") + + createQueryTest("case else null", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""") + + createQueryTest("having no references", + "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") + createQueryTest("boolean = number", """ |SELECT @@ -419,10 +433,10 @@ class HiveQuerySuite extends HiveComparisonTest { hql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) - hql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - hql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + hql("set some.property=20") + assert(get("some.property", "0") == "20") + hql("set some.property = 40") + assert(get("some.property", "0") == "40") hql(s"set $testKey=$testVal") assert(get(testKey, "0") == testVal) @@ -436,63 +450,61 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet clear() // "set" itself returns all config variables currently specified in SQLConf. 
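As the reworked SET assertions below show, a SET command now returns rows carrying a single string column formatted as "key=value" rather than separate key and value columns. A minimal usage sketch, assuming a HiveContext instance named `hiveContext` (hypothetical):

// Each returned row holds one "key=value" string in column 0.
val settings: Array[String] = hiveContext.hql("SET").collect().map(_.getString(0))
settings.foreach(println)  // e.g. spark.sql.key.usedfortestonly=test.val.0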
assert(hql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } hql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + hql(s"SET").collect().map(_.getString(0)) } // "set key" - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + hql(s"SET $nonexistentKey").collect().map(_.getString(0)) } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + sql("SET").collect().map(_.getString(0)) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + sql("SET").collect().map(_.getString(0)) } - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala new file mode 100644 index 0000000000000..c5736723b47c0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +class HiveTableScanSuite extends HiveComparisonTest { + + createQueryTest("partition_based_table_scan_with_different_serde", + """ + |CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) + |ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' + |STORED AS RCFILE; + | + |FROM src + |INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') + |SELECT 100,100 LIMIT 1; + | + |ALTER TABLE part_scan_test SET SERDE + |'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe'; + | + |FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') + |SELECT 200,200 LIMIT 1; + | + |SELECT * from part_scan_test; + """.stripMargin) +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 91ad59d7f82c0..47526e3596e44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} @@ -27,6 +29,8 @@ import org.apache.spark.util.Utils // Implicits import org.apache.spark.sql.hive.test.TestHive._ +case class Cases(lower: String, UPPER: String) + class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val dirname = Utils.createTempDir() @@ -35,7 +39,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft override def beforeAll() { // write test data - ParquetTestData.writeFile + ParquetTestData.writeFile() testRDD = parquetFile(ParquetTestData.testDir.toString) testRDD.registerAsTable("testsource") } @@ -55,6 +59,19 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft Utils.deleteRecursively(dirname) } + test("Case insensitive attribute names") { + val tempFile = File.createTempFile("parquet", "") + tempFile.delete() + sparkContext.parallelize(1 to 10) + .map(_.toString) + .map(i => Cases(i, i)) + .saveAsParquetFile(tempFile.getCanonicalPath) + + parquetFile(tempFile.getCanonicalPath).registerAsTable("cases") + hql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + hql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + } + test("SELECT on Parquet table") { val rdd = hql("SELECT * FROM testsource").collect() assert(rdd != null) diff --git a/streaming/pom.xml b/streaming/pom.xml index f60697ce745b7..1072f74aea0d9 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming_2.10 - streaming + streaming jar Spark Project Streaming @@ -58,6 +58,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 09be3a50d2dfa..1f0244c251eba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -138,7 +138,7 @@ private[streaming] abstract class ReceiverSupervisor( onReceiverStop(message, error) } catch { case t: Throwable => - stop("Error stopping receiver " + streamId, Some(t)) + logError("Error stopping receiver " + streamId + t.getStackTraceString) } } diff --git a/tools/pom.xml b/tools/pom.xml index c0ee8faa7a615..97abb6b2b63e0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -27,7 +27,7 @@ org.apache.spark spark-tools_2.10 - tools + tools jar Spark Project Tools diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 566983675bff5..bcf6d43ab34eb 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -68,12 +68,11 @@ object GenerateMIMAIgnore { for (className <- classes) { try { val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) - val moduleSymbol = mirror.staticModule(className) // TODO: see if it is necessary. + val moduleSymbol = mirror.staticModule(className) val directlyPrivateSpark = isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) - val developerApi = isDeveloperApi(classSymbol) - val experimental = isExperimental(classSymbol) - + val developerApi = isDeveloperApi(classSymbol) || isDeveloperApi(moduleSymbol) + val experimental = isExperimental(classSymbol) || isExperimental(moduleSymbol) /* Inner classes defined within a private[spark] class or object are effectively invisible, so we account for them as package private. */ lazy val indirectlyPrivateSpark = { @@ -87,10 +86,9 @@ object GenerateMIMAIgnore { } if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) { ignoredClasses += className - } else { - // check if this class has package-private/annotated members. - ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } + // check if this class has package-private/annotated members. 
+ ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { case _: Throwable => println("Error instrumenting class:" + className) @@ -115,9 +113,11 @@ object GenerateMIMAIgnore { } private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { - classSymbol.typeSignature.members - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ - getInnerFunctions(classSymbol) + classSymbol.typeSignature.members.filterNot(x => + x.fullName.startsWith("java") || x.fullName.startsWith("scala") + ).filter(x => + isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x) + ).map(_.fullName) ++ getInnerFunctions(classSymbol) } def main(args: Array[String]) { @@ -137,8 +137,7 @@ object GenerateMIMAIgnore { name.endsWith("$class") || name.contains("$sp") || name.contains("hive") || - name.contains("Hive") || - name.contains("repl") + name.contains("Hive") } /** diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8e8c35615a711..8a05fcb449aa6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -61,10 +61,9 @@ object StoragePerfTester { for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) } - writers.map {w => - w.commit() + writers.map { w => + w.commitAndClose() total.addAndGet(w.fileSegment().length) - w.close() } shuffle.releaseWriters(true) diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 5b13a1f002d6e..51744ece0412d 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-alpha + yarn-alpha org.apache.spark diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index a1298e8f30b5c..b7e8636e02eb2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -109,7 +109,7 @@ trait ClientBase extends Logging { if (amMem > maxMem) { val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." - .format(args.amMemory, maxMem) + .format(amMem, maxMem) logError(errorMessage) throw new IllegalArgumentException(errorMessage) } diff --git a/yarn/pom.xml b/yarn/pom.xml index efb473aa1b261..3faaf053634d6 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -29,7 +29,7 @@ pom Spark Project YARN Parent POM - yarn + yarn diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index ceaf9f9d71001..b6c8456d06684 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-stable + yarn-stable org.apache.spark