diff --git a/.gitignore b/.gitignore index 34939e3a97aaa..c67cffa1c4375 100644 --- a/.gitignore +++ b/.gitignore @@ -49,7 +49,7 @@ dependency-reduced-pom.xml checkpoint derby.log dist/ -spark-*-bin.tar.gz +spark-*-bin-*.tgz unit-tests.log /lib/ rat-results.txt diff --git a/.rat-excludes b/.rat-excludes index 20e3372464386..d8bee1f8e49c9 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -44,6 +44,7 @@ SparkImports.scala SparkJLineCompletion.scala SparkJLineReader.scala SparkMemberHandlers.scala +SparkReplReporter.scala sbt sbt-launch-lib.bash plugins.sbt diff --git a/assembly/pom.xml b/assembly/pom.xml index 31a01e4d8e1de..4e2b773e7d2f3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -66,22 +66,22 @@ org.apache.spark - spark-repl_${scala.binary.version} + spark-streaming_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming_${scala.binary.version} + spark-graphx_${scala.binary.version} ${project.version} org.apache.spark - spark-graphx_${scala.binary.version} + spark-sql_${scala.binary.version} ${project.version} org.apache.spark - spark-sql_${scala.binary.version} + spark-repl_${scala.binary.version} ${project.version} @@ -197,6 +197,11 @@ spark-hive_${scala.binary.version} ${project.version} + + + + hive-thriftserver + org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/bagel/pom.xml b/bagel/pom.xml index 93db0d5efda5f..0327ffa402671 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 905bbaf99b374..298641f2684de 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -20,8 +20,6 @@ # This script computes Spark's classpath and prints it to stdout; it's used by both the "run" # script and the ExecutorRunner in standalone cluster mode. -SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -36,7 +34,7 @@ else CLASSPATH="$CLASSPATH:$FWDIR/conf" fi -ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION" +ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION" if [ -n "$JAVA_HOME" ]; then JAR_CMD="$JAVA_HOME/bin/jar" @@ -48,19 +46,19 @@ fi if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." 
>&2 - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" fi # Use spark-assembly jar from either RELEASE or assembly directory @@ -123,15 +121,15 @@ fi # Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 if [[ $SPARK_TESTING == 1 ]]; then - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" + 
CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 6d4231b204595..356b3d49b2ffe 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -36,3 +36,23 @@ if [ -z "$SPARK_ENV_LOADED" ]; then set +a fi fi + +# Setting SPARK_SCALA_VERSION if not already set. + +if [ -z "$SPARK_SCALA_VERSION" ]; then + + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" + + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then + echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 + echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 + exit 1 + fi + + if [ -d "$ASSEMBLY_DIR2" ]; then + export SPARK_SCALA_VERSION="2.11" + else + export SPARK_SCALA_VERSION="2.10" + fi +fi diff --git a/bin/pyspark b/bin/pyspark index 96f30a260a09e..0b4f695dd06dd 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,7 +25,7 @@ export SPARK_HOME="$FWDIR" source "$FWDIR/bin/utils.sh" -SCALA_VERSION=2.10 +source "$FWDIR"/bin/load-spark-env.sh function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 @@ -40,7 +40,7 @@ fi # Exit if the user hasn't compiled Spark if [ ! -f "$FWDIR/RELEASE" ]; then # Exit if the user hasn't compiled Spark - ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null + ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null if [[ $? != 0 ]]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -48,8 +48,6 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi -. "$FWDIR"/bin/load-spark-env.sh - # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. # @@ -134,7 +132,5 @@ if [[ "$1" =~ \.py$ ]]; then gatherSparkSubmitOpts "$@" exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" else - # PySpark shell requires special handling downstream - export PYSPARK_SHELL=1 exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 59415e9bdec2c..a542ec80b49d6 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -59,7 +59,6 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do ( ) if [%PYTHON_FILE%] == [] ( - set PYSPARK_SHELL=1 if [%IPYTHON%] == [1] ( ipython %IPYTHON_OPTS% ) else ( diff --git a/bin/run-example b/bin/run-example index 34dd71c71880e..3d932509426fc 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,12 +17,12 @@ # limitations under the License. # -SCALA_VERSION=2.10 - FWDIR="$(cd "`dirname "$0"`"/..; pwd)" export SPARK_HOME="$FWDIR" EXAMPLES_DIR="$FWDIR"/examples +. 
"$FWDIR"/bin/load-spark-env.sh + if [ -n "$1" ]; then EXAMPLE_CLASS="$1" shift @@ -36,8 +36,8 @@ fi if [ -f "$FWDIR/RELEASE" ]; then export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`" -elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then - export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`" +elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then + export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`" fi if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then diff --git a/bin/spark-class b/bin/spark-class index 925367b0dd187..0d58d95c1aee3 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -24,8 +24,6 @@ case "`uname`" in CYGWIN*) cygwin=true;; esac -SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -128,9 +126,9 @@ fi TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" -if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then +if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the SBT build - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`" + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`" fi if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the Maven build @@ -149,7 +147,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then - echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 + echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 echo "You need to build Spark before running $1." 1>&2 exit 1 fi diff --git a/bin/spark-sql b/bin/spark-sql index 63d00437d508d..3b6cc420fea81 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -23,6 +23,8 @@ # Enter posix mode for bash set -o posix +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" # Figure out where Spark is installed diff --git a/bin/spark-submit b/bin/spark-submit index c557311b4b20e..f92d90c3a66b0 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,6 +22,9 @@ export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" ORIG_ARGS=("$@") +# Set COLUMNS for progress bar +export COLUMNS=`tput cols` + while (($#)); do if [ "$1" = "--deploy-mode" ]; then SPARK_SUBMIT_DEPLOY_MODE=$2 diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index f8ffbf64278fb..0886b0276fb90 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -28,7 +28,7 @@ # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. # - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job. -# Options for the daemons used in the standalone deploy mode: +# Options for the daemons used in the standalone deploy mode # - SPARK_MASTER_IP, to bind the master to a different IP address or hostname # - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master # - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. "-Dx=y") @@ -41,3 +41,10 @@ # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. 
"-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers + +# Generic options for the daemons used in the standalone deploy mode +# - SPARK_CONF_DIR Alternate conf dir. (Default: ${SPARK_HOME}/conf) +# - SPARK_LOG_DIR Where log files are stored. (Default: ${SPARK_HOME}/logs) +# - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp) +# - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER) +# - SPARK_NICENESS The scheduling priority for daemons. (Default: 0) diff --git a/core/pom.xml b/core/pom.xml index 41296e0eca330..1feb00b3a7fb8 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -34,6 +34,34 @@ Spark Project Core http://spark.apache.org/ + + com.twitter + chill_${scala.binary.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + + + com.twitter + chill-java + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + org.apache.hadoop hadoop-client @@ -46,12 +74,12 @@ org.apache.spark - spark-network-common_2.10 + spark-network-common_${scala.binary.version} ${project.version} org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_${scala.binary.version} ${project.version} @@ -132,14 +160,6 @@ net.jpountz.lz4 lz4 - - com.twitter - chill_${scala.binary.version} - - - com.twitter - chill-java - org.roaringbitmap RoaringBitmap @@ -309,14 +329,16 @@ org.scalatest scalatest-maven-plugin - - - ${basedir}/.. - 1 - ${spark.classpath} - - + + + test + + test + + + + org.apache.maven.plugins @@ -424,4 +446,5 @@ + diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java index 04e2247210ecc..fd74321093658 100644 --- a/core/src/main/java/org/apache/spark/SparkStageInfo.java +++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java @@ -26,6 +26,7 @@ public interface SparkStageInfo { int stageId(); int currentAttemptId(); + long submissionTime(); String name(); int numTasks(); int numActiveTasks(); diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala index 7f91de653a64a..0f9bac7164162 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/package.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala @@ -22,4 +22,4 @@ package org.apache.spark.api.java * these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's * Java programming guide for more details. */ -package object function \ No newline at end of file +package object function diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index badd85ed48c82..14ba37d7c9bd9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -26,26 +26,24 @@ $(function() { // Switch the class of the arrow from open to closed. 
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); - - // If clicking caused the metrics to expand, automatically check all options for additional - // metrics (don't trigger a click when collapsing metrics, because it leads to weird - // toggling behavior). - if (!$(additionalMetricsDiv).hasClass('collapsed')) { - $(this).parent().find('input:checkbox:not(:checked)').trigger('click'); - } }); - $("input:checkbox:not(:checked)").each(function() { - var column = "table ." + $(this).attr("name"); - $(column).hide(); - }); - // Stripe table rows after rows have been hidden to ensure correct striping. - stripeTables(); + stripeSummaryTable(); $("input:checkbox").click(function() { var column = "table ." + $(this).attr("name"); $(column).toggle(); - stripeTables(); + stripeSummaryTable(); + }); + + $("#select-all-metrics").click(function() { + if (this.checked) { + // Toggle all un-checked options. + $('input:checkbox:not(:checked)').trigger('click'); + } else { + // Toggle all checked options. + $('input:checkbox:checked').trigger('click'); + } }); // Trigger a click on the checkbox if a user clicks the label next to it. diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 6bb03015abb51..656147e40d13e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -15,16 +15,18 @@ * limitations under the License. */ -/* Adds background colors to stripe table rows. This is necessary (instead of using css or the - * table striping provided by bootstrap) to appropriately stripe tables with hidden rows. */ -function stripeTables() { - $("table.table-striped-custom").each(function() { - $(this).find("tr:not(:hidden)").each(function (index) { - if (index % 2 == 1) { - $(this).css("background-color", "#f9f9f9"); - } else { - $(this).css("background-color", "#ffffff"); - } - }); +/* Adds background colors to stripe table rows in the summary table (on the stage page). This is + * necessary (instead of using css or the table striping provided by bootstrap) because the summary + * table has hidden rows. + * + * An ID selector (rather than a class selector) is used to ensure this runs quickly even on pages + * with thousands of task rows (ID selectors are much faster than class selectors). */ +function stripeSummaryTable() { + $("#task-summary-table").find("tr:not(:hidden)").each(function (index) { + if (index % 2 == 1) { + $(this).css("background-color", "#f9f9f9"); + } else { + $(this).css("background-color", "#ffffff"); + } }); } diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index db57712c83503..cdf85bfbf326f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -168,3 +168,9 @@ span.additional-metric-title { border-left: 5px solid black; display: inline-block; } + +/* Hide all additional metrics by default. This is done here rather than using JavaScript to + * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. 
*/ +.scheduler_delay, .gc_time, .deserialization_time, .serialization_time, .getting_result_time { + display: none; +} diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 2301caafb07ff..000bbd6b532ad 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} +import java.util.concurrent.atomic.AtomicLong import scala.collection.generic.Growable import scala.collection.mutable.Map @@ -228,6 +229,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa */ class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) extends Accumulable[T,T](initialValue, param, name) { + def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) } @@ -244,6 +246,36 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] { } } +object AccumulatorParam { + + // The following implicit objects were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, as there are duplicate codes in SparkContext for backward + // compatibility, please update them accordingly if you modify the following implicit objects. + + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 + def zero(initialValue: Double) = 0.0 + } + + implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 + def zero(initialValue: Int) = 0 + } + + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0L + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def zero(initialValue: Float) = 0f + } + + // TODO: Add AccumulatorParams for other types, e.g. lists and strings +} + // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private object Accumulators { @@ -252,7 +284,7 @@ private object Accumulators { val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]() var lastId: Long = 0 - def newId: Long = synchronized { + def newId(): Long = synchronized { lastId += 1 lastId } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index ef93009a074e7..88adb892998af 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -28,7 +28,9 @@ import org.apache.spark.scheduler._ * the scheduler queue is not drained in N seconds, then new executors are added. If the queue * persists for another M seconds, then more executors are added and so on. The number added * in each round increases exponentially from the previous round until an upper bound on the - * number of executors has been reached. + * number of executors has been reached. 
The upper bound is based both on a configured property + * and on the number of tasks pending: the policy will never increase the number of executor + * requests past the number needed to handle all pending tasks. * * The rationale for the exponential increase is twofold: (1) Executors should be added slowly * in the beginning in case the number of extra executors needed turns out to be small. Otherwise, @@ -82,6 +84,12 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + // TODO: The default value of 1 for spark.executor.cores works right now because dynamic + // allocation is only supported for YARN and the default number of cores per executor in YARN is + // 1, but it might need to be attained differently for different cluster managers + private val tasksPerExecutor = + conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + validateSettings() // Number of executors to add in the next round @@ -110,6 +118,9 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock + // Listener for Spark events that impact the allocation policy + private val listener = new ExecutorAllocationListener(this) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -141,6 +152,9 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } + if (tasksPerExecutor == 0) { + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores") + } } /** @@ -154,7 +168,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging * Register for scheduler callbacks to decide when to add and remove executors. */ def start(): Unit = { - val listener = new ExecutorAllocationListener(this) sc.addSparkListener(listener) startPolling() } @@ -218,13 +231,27 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging return 0 } - // Request executors with respect to the upper bound - val actualNumExecutorsToAdd = - if (numExistingExecutors + numExecutorsToAdd <= maxNumExecutors) { - numExecutorsToAdd - } else { - maxNumExecutors - numExistingExecutors - } + // The number of executors needed to satisfy all pending tasks is the number of tasks pending + // divided by the number of tasks each executor can fit, rounded up. + val maxNumExecutorsPending = + (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor + if (numExecutorsPending >= maxNumExecutorsPending) { + logDebug(s"Not adding executors because there are already $numExecutorsPending " + + s"pending and pending tasks could only fill $maxNumExecutorsPending") + numExecutorsToAdd = 1 + return 0 + } + + // It's never useful to request more executors than could satisfy all the pending tasks, so + // cap request at that amount. + // Also cap request with respect to the configured upper bound. 
+ val maxNumExecutorsToAdd = math.min( + maxNumExecutorsPending - numExecutorsPending, + maxNumExecutors - numExistingExecutors) + assert(maxNumExecutorsToAdd > 0) + + val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd) + val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd) if (addRequestAcknowledged) { @@ -445,6 +472,16 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) } + + /** + * An estimate of the total number of pending tasks remaining for currently running stages. Does + * not account for tasks which may have failed and been resubmitted. + */ + def totalPendingTasks(): Int = { + stageIdToNumTasks.map { case (stageId, numTasks) => + numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0) + }.sum + } } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4c6c86c7bad78..c14764f773982 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -61,7 +61,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { throw new NullPointerException("null key") } if (value == null) { - throw new NullPointerException("null value") + throw new NullPointerException("null value for " + key) } settings(key) = value this diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 03ea672c813d1..9b0d5be7a7ab2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -25,6 +25,7 @@ import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger import java.util.UUID.randomUUID import scala.collection.{Map, Set} +import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} @@ -49,7 +50,7 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ @@ -57,17 +58,33 @@ import org.apache.spark.util._ * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * + * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. + * * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. */ +class SparkContext(config: SparkConf) extends Logging { + + // The call site where this SparkContext was constructed. 
+ private val creationSite: CallSite = Utils.getCallSite() + + // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active + private val allowMultipleContexts: Boolean = + config.getBoolean("spark.driver.allowMultipleContexts", false) -class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having started construction. + // NOTE: this must be placed at the beginning of the SparkContext constructor. + SparkContext.markPartiallyConstructed(this, allowMultipleContexts) // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It // contains a map from hostname to a list of input format splits on the host. private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() + val startTime = System.currentTimeMillis() + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -228,6 +245,15 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { private[spark] val jobProgressListener = new JobProgressListener(conf) listenerBus.addListener(jobProgressListener) + val statusTracker = new SparkStatusTracker(this) + + private[spark] val progressBar: Option[ConsoleProgressBar] = + if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + // Initialize the Spark UI private[spark] val ui: Option[SparkUI] = if (conf.getBoolean("spark.ui.enabled", true)) { @@ -245,8 +271,6 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) - val startTime = System.currentTimeMillis() - // Add each JAR given through the constructor if (jars != null) { jars.foreach(addJar) @@ -1001,6 +1025,69 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** The version of Spark on which this application is running. */ def version = SPARK_VERSION + /** + * Return a map from the slave to the max memory available for caching and the remaining + * memory available for caching. + */ + def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => + (blockManagerId.host + ":" + blockManagerId.port, mem) + } + } + + /** + * :: DeveloperApi :: + * Return information about what RDDs are cached, if they are in mem or on disk, how much space + * they take, etc. + */ + @DeveloperApi + def getRDDStorageInfo: Array[RDDInfo] = { + val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.filter(_.isCached) + } + + /** + * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. + * Note that this does not necessarily mean the caching or computation was successful. 
+ */ + def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap + + /** + * :: DeveloperApi :: + * Return information about blocks stored in all of the slaves + */ + @DeveloperApi + def getExecutorStorageStatus: Array[StorageStatus] = { + env.blockManager.master.getStorageStatus + } + + /** + * :: DeveloperApi :: + * Return pools for fair scheduler + */ + @DeveloperApi + def getAllPools: Seq[Schedulable] = { + // TODO(xiajunluan): We should take nested pools into account + taskScheduler.rootPool.schedulableQueue.toSeq + } + + /** + * :: DeveloperApi :: + * Return the pool associated with the given name, if one exists + */ + @DeveloperApi + def getPoolForName(pool: String): Option[Schedulable] = { + Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) + } + + /** + * Return current scheduling mode + */ + def getSchedulingMode: SchedulingMode.SchedulingMode = { + taskScheduler.schedulingMode + } + /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. @@ -1100,27 +1187,30 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** Shut down the SparkContext. */ def stop() { - postApplicationEnd() - ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { - env.metricsSystem.report() - metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) - cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) - listenerBus.stop() - eventLogger.foreach(_.stop()) - logInfo("Successfully stopped SparkContext") - } else { - logInfo("SparkContext already stopped") + SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + postApplicationEnd() + ui.foreach(_.stop()) + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. + val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { + env.metricsSystem.report() + metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) + cleaner.foreach(_.stop()) + dagSchedulerCopy.stop() + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + SparkEnv.set(null) + listenerBus.stop() + eventLogger.foreach(_.stop()) + logInfo("Successfully stopped SparkContext") + SparkContext.clearActiveContext() + } else { + logInfo("SparkContext already stopped") + } } } @@ -1191,6 +1281,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { logInfo("Starting job: " + callSite.shortForm) dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) + progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } @@ -1409,6 +1500,11 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { private[spark] def cleanup(cleanupTime: Long) { persistentRdds.clearOldValues(cleanupTime) } + + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having finished construction. + // NOTE: this must be placed at the end of the SparkContext constructor. 
+ SparkContext.setActiveContext(this, allowMultipleContexts) } /** @@ -1417,6 +1513,107 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { */ object SparkContext extends Logging { + /** + * Lock that guards access to global variables that track SparkContext construction. + */ + private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() + + /** + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. + * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var activeContext: Option[SparkContext] = None + + /** + * Points to a partially-constructed SparkContext if some thread is in the SparkContext + * constructor, or `None` if no SparkContext is being constructed. + * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var contextBeingConstructed: Option[SparkContext] = None + + /** + * Called to ensure that no other SparkContext is running in this JVM. + * + * Throws an exception if a running context is detected and logs a warning if another thread is + * constructing a SparkContext. This warning is necessary because the current locking scheme + * prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. + */ + private def assertNoOtherContextIsRunning( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + contextBeingConstructed.foreach { otherContext => + if (otherContext ne sc) { // checks for reference equality + // Since otherContext might point to a partially-constructed context, guard against + // its creationSite field being null: + val otherContextCreationSite = + Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location") + val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" + + " constructor). This may indicate an error, since only one SparkContext may be" + + " running in this JVM (see SPARK-2243)." + + s" The other SparkContext was created at:\n$otherContextCreationSite" + logWarning(warnMsg) + } + + activeContext.foreach { ctx => + val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" + val exception = new SparkException(errMsg) + if (allowMultipleContexts) { + logWarning("Multiple running SparkContexts detected in the same JVM!", exception) + } else { + throw exception + } + } + } + } + } + + /** + * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is + * running. Throws an exception if a running context is detected and logs a warning if another + * thread is constructing a SparkContext. This warning is necessary because the current locking + * scheme prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. 
+ */ + private[spark] def markPartiallyConstructed( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = Some(sc) + } + } + + /** + * Called at the end of the SparkContext constructor to ensure that no other SparkContext has + * raced with this constructor and started. + */ + private[spark] def setActiveContext( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = None + activeContext = Some(sc) + } + } + + /** + * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's + * also called in unit tests to prevent a flood of warnings from test suites that don't / can't + * properly clean up their SparkContexts. + */ + private[spark] def clearActiveContext(): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + activeContext = None + } + } + private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" @@ -1427,47 +1624,74 @@ object SparkContext extends Logging { private[spark] val DRIVER_IDENTIFIER = "" - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + // The following deprecated objects have already been copied to `object AccumulatorParam` to + // make the compiler find them automatically. They are duplicate codes only for backward + // compatibility, please update `object AccumulatorParam` accordingly if you plan to modify the + // following ones. + + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 } - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 def zero(initialValue: Int) = 0 } - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object LongAccumulatorParam extends AccumulatorParam[Long] { def addInPlace(t1: Long, t2: Long) = t1 + t2 def zero(initialValue: Long) = 0L } - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object FloatAccumulatorParam extends AccumulatorParam[Float] { def addInPlace(t1: Float, t2: Float) = t1 + t2 def zero(initialValue: Float) = 0f } - // TODO: Add AccumulatorParams for other types, e.g. lists and strings + // The following deprecated functions have already been moved to `object RDD` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object RDD`. - implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + @deprecated("Replaced by implicit functions in the RDD companion object. 
This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { - new PairRDDFunctions(rdd) + RDD.rddToPairRDDFunctions(rdd) } - implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( rdd: RDD[(K, V)]) = - new SequenceFileRDDFunctions(rdd) + RDD.rddToSequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( rdd: RDD[(K, V)]) = - new OrderedRDDFunctions[K, V, (K, V)](rdd) + RDD.rddToOrderedRDDFunctions(rdd) - implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) - implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = - new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + RDD.numericRDDToDoubleRDDFunctions(rdd) // Implicit conversions to common Writable types, for saveAsSequenceFile @@ -1493,40 +1717,49 @@ object SparkContext extends Logging { arr.map(x => anyToWritable(x)).toArray) } - // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) - : WritableConverter[T] = { - val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] - new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) - } + // The following deprecated functions have already been moved to `object WritableConverter` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object WritableConverter`. - implicit def intWritableConverter(): WritableConverter[Int] = - simpleWritableConverter[Int, IntWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def intWritableConverter(): WritableConverter[Int] = + WritableConverter.intWritableConverter() - implicit def longWritableConverter(): WritableConverter[Long] = - simpleWritableConverter[Long, LongWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. 
This is kept here only for " + + "backward compatibility.", "1.2.0") + def longWritableConverter(): WritableConverter[Long] = + WritableConverter.longWritableConverter() - implicit def doubleWritableConverter(): WritableConverter[Double] = - simpleWritableConverter[Double, DoubleWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def doubleWritableConverter(): WritableConverter[Double] = + WritableConverter.doubleWritableConverter() - implicit def floatWritableConverter(): WritableConverter[Float] = - simpleWritableConverter[Float, FloatWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def floatWritableConverter(): WritableConverter[Float] = + WritableConverter.floatWritableConverter() - implicit def booleanWritableConverter(): WritableConverter[Boolean] = - simpleWritableConverter[Boolean, BooleanWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def booleanWritableConverter(): WritableConverter[Boolean] = + WritableConverter.booleanWritableConverter() - implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](bw => - // getBytes method returns array which is longer then data to be returned - Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) - ) - } + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def bytesWritableConverter(): WritableConverter[Array[Byte]] = + WritableConverter.bytesWritableConverter() - implicit def stringWritableConverter(): WritableConverter[String] = - simpleWritableConverter[String, Text](_.toString) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def stringWritableConverter(): WritableConverter[String] = + WritableConverter.stringWritableConverter() - implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def writableWritableConverter[T <: Writable]() = + WritableConverter.writableWritableConverter() /** * Find the JAR from which a given class was loaded, to make it easy for users to pass @@ -1616,6 +1849,9 @@ object SparkContext extends Logging { def localCpuCount = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. 
val threadCount = if (threads == "*") localCpuCount else threads.toInt + if (threadCount <= 0) { + throw new SparkException(s"Asked to run locally with $threadCount threads") + } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) @@ -1750,3 +1986,46 @@ private[spark] class WritableConverter[T]( val writableClass: ClassTag[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable + +object WritableConverter { + + // Helper objects for converting common types to Writable + private[spark] def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) + : WritableConverter[T] = { + val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] + new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) + } + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. + + implicit def intWritableConverter(): WritableConverter[Int] = + simpleWritableConverter[Int, IntWritable](_.get) + + implicit def longWritableConverter(): WritableConverter[Long] = + simpleWritableConverter[Long, LongWritable](_.get) + + implicit def doubleWritableConverter(): WritableConverter[Double] = + simpleWritableConverter[Double, DoubleWritable](_.get) + + implicit def floatWritableConverter(): WritableConverter[Float] = + simpleWritableConverter[Float, FloatWritable](_.get) + + implicit def booleanWritableConverter(): WritableConverter[Boolean] = + simpleWritableConverter[Boolean, BooleanWritable](_.get) + + implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) + } + + implicit def stringWritableConverter(): WritableConverter[String] = + simpleWritableConverter[String, Text](_.toString) + + implicit def writableWritableConverter[T <: Writable]() = + new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e7454beddbfd0..e464b32e61dd6 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -168,9 +168,11 @@ object SparkEnv extends Logging { executorId: String, hostname: String, port: Int, + numCores: Int, isLocal: Boolean, actorSystem: ActorSystem = null): SparkEnv = { - create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem) + create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem, + numUsableCores = numCores) } /** @@ -184,7 +186,8 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean, listenerBus: LiveListenerBus = null, - defaultActorSystem: ActorSystem = null): SparkEnv = { + defaultActorSystem: ActorSystem = null, + numUsableCores: Int = 0): SparkEnv = { // Listener bus is only used on the driver if (isDriver) { @@ -276,7 +279,7 @@ object SparkEnv extends Logging { val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { 
case "netty" => - new NettyBlockTransferService(conf, securityManager) + new NettyBlockTransferService(conf, securityManager, numUsableCores) case "nio" => new NioBlockTransferService(conf, securityManager) } @@ -287,7 +290,8 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, + numUsableCores) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala deleted file mode 100644 index 1982499c5e1d3..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.Map -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SchedulingMode, Schedulable} -import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo} - -/** - * Trait that implements Spark's status APIs. This trait is designed to be mixed into - * SparkContext; it allows the status API code to live in its own file. - */ -private[spark] trait SparkStatusAPI { this: SparkContext => - - /** - * Return a map from the slave to the max memory available for caching and the remaining - * memory available for caching. - */ - def getExecutorMemoryStatus: Map[String, (Long, Long)] = { - env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.host + ":" + blockManagerId.port, mem) - } - } - - /** - * :: DeveloperApi :: - * Return information about what RDDs are cached, if they are in mem or on disk, how much space - * they take, etc. - */ - @DeveloperApi - def getRDDStorageInfo: Array[RDDInfo] = { - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) - rddInfos.filter(_.isCached) - } - - /** - * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. 
- */ - def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - def getExecutorStorageStatus: Array[StorageStatus] = { - env.blockManager.master.getStorageStatus - } - - /** - * :: DeveloperApi :: - * Return pools for fair scheduler - */ - @DeveloperApi - def getAllPools: Seq[Schedulable] = { - // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq - } - - /** - * :: DeveloperApi :: - * Return the pool associated with the given name, if one exists - */ - @DeveloperApi - def getPoolForName(pool: String): Option[Schedulable] = { - Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) - } - - /** - * Return current scheduling mode - */ - def getSchedulingMode: SchedulingMode.SchedulingMode = { - taskScheduler.schedulingMode - } - - - /** - * Return a list of all known jobs in a particular job group. The returned list may contain - * running, failed, and completed jobs, and may vary across invocations of this method. This - * method does not guarantee the order of the elements in its result. - */ - def getJobIdsForGroup(jobGroup: String): Array[Int] = { - jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.exists(_ == jobGroup)).map(_.jobId).toArray - } - } - - /** - * Returns job information, or `None` if the job info could not be found or was garbage collected. - */ - def getJobInfo(jobId: Int): Option[SparkJobInfo] = { - jobProgressListener.synchronized { - jobProgressListener.jobIdToData.get(jobId).map { data => - new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) - } - } - } - - /** - * Returns stage information, or `None` if the stage info could not be found or was - * garbage collected. - */ - def getStageInfo(stageId: Int): Option[SparkStageInfo] = { - jobProgressListener.synchronized { - for ( - info <- jobProgressListener.stageIdToInfo.get(stageId); - data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) - ) yield { - new SparkStageInfoImpl( - stageId, - info.attemptId, - info.name, - info.numTasks, - data.numActiveTasks, - data.numCompleteTasks, - data.numFailedTasks) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala new file mode 100644 index 0000000000000..edbdda8a0bcb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Low-level status reporting APIs for monitoring job and stage progress. 
+ * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `None` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. + */ +class SparkStatusTracker private[spark] (sc: SparkContext) { + + private val jobProgressListener = sc.jobProgressListener + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = { + jobProgressListener.synchronized { + val jobData = jobProgressListener.jobIdToData.valuesIterator + jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + } + } + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeStages.values.map(_.stageId).toArray + } + } + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveJobIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeJobs.values.map(_.jobId).toArray + } + } + + /** + * Returns job information, or `None` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): Option[SparkJobInfo] = { + jobProgressListener.synchronized { + jobProgressListener.jobIdToData.get(jobId).map { data => + new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) + } + } + } + + /** + * Returns stage information, or `None` if the stage info could not be found or was + * garbage collected. 
+ */ + def getStageInfo(stageId: Int): Option[SparkStageInfo] = { + jobProgressListener.synchronized { + for ( + info <- jobProgressListener.stageIdToInfo.get(stageId); + data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) + ) yield { + new SparkStageInfoImpl( + stageId, + info.attemptId, + info.submissionTime.getOrElse(0), + info.name, + info.numTasks, + data.numActiveTasks, + data.numCompleteTasks, + data.numFailedTasks) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index 90b47c847fbca..e5c7c8d0db578 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -26,6 +26,7 @@ private class SparkJobInfoImpl ( private class SparkStageInfoImpl( val stageId: Int, val currentAttemptId: Int, + val submissionTime: Long, val name: String, val numTasks: Int, val numActiveTasks: Int, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index e37f3acaf6e30..7af3538262fd6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -32,13 +32,13 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ -import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} +import org.apache.spark.rdd.RDD.rddToPairRDDFunctions import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 5c6e8d32c5c8a..97f5c9f257e09 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.SparkContext._ +import org.apache.spark.AccumulatorParam._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast @@ -42,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. + * + * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. 
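The SparkStatusTracker above exposes the job/stage queries that previously lived in the SparkStatusAPI trait, plus getActiveJobIds and getActiveStageIds. A minimal sketch of polling it from driver code while a job runs on another thread; the accessor sc.statusTracker is assumed here (mirroring the JavaSparkContext.statusTracker accessor added below), and the job body, app name, and sleep interval are illustrative, not part of this patch:

import org.apache.spark.{SparkConf, SparkContext}

object StatusTrackerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("status-tracker-demo").setMaster("local[2]"))

    // Run a job on a separate thread so the driver thread is free to poll for progress.
    val worker = new Thread {
      override def run(): Unit = {
        sc.parallelize(1 to 10000, 8).map { i => Thread.sleep(1); i * 2 }.count()
      }
    }
    worker.start()

    while (worker.isAlive) {
      val tracker = sc.statusTracker  // assumed accessor, analogous to JavaSparkContext.statusTracker
      for (jobId <- tracker.getActiveJobIds(); job <- tracker.getJobInfo(jobId)) {
        // Stage info may be missing for a valid id; absent data simply means "unknown",
        // per the weak-consistency contract documented above.
        val stages = job.stageIds().flatMap(tracker.getStageInfo(_))
        val active = stages.map(_.numActiveTasks()).sum
        val total = stages.map(_.numTasks()).sum
        println(s"Job $jobId (${job.status()}): $active active of $total tasks")
      }
      Thread.sleep(100)
    }
    worker.join()
    sc.stop()
  }
}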
*/ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround with Closeable { @@ -105,6 +108,8 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env + def statusTracker = new JavaSparkStatusTracker(sc) + def isLocal: java.lang.Boolean = sc.isLocal def sparkUser: String = sc.sparkUser @@ -134,25 +139,6 @@ class JavaSparkContext(val sc: SparkContext) /** Default min number of partitions for Hadoop RDDs when not given by user */ def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions - - /** - * Return a list of all known jobs in a particular job group. The returned list may contain - * running, failed, and completed jobs, and may vary across invocations of this method. This - * method does not guarantee the order of the elements in its result. - */ - def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.getJobIdsForGroup(jobGroup) - - /** - * Returns job information, or `null` if the job info could not be found or was garbage collected. - */ - def getJobInfo(jobId: Int): SparkJobInfo = sc.getJobInfo(jobId).orNull - - /** - * Returns stage information, or `null` if the stage info could not be found or was - * garbage collected. - */ - def getStageInfo(stageId: Int): SparkStageInfo = sc.getStageInfo(stageId).orNull - /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala new file mode 100644 index 0000000000000..3300cad9efbab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext} + +/** + * Low-level status reporting APIs for monitoring job and stage progress. + * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `null` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. 
+ */ +class JavaSparkStatusTracker private[spark] (sc: SparkContext) { + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup) + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds() + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds() + + /** + * Returns job information, or `null` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull + + /** + * Returns stage information, or `null` if the stage info could not be found or was + * garbage collected. + */ + def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 45beb8fc8c925..e0bc00e1eb249 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections} import org.apache.spark.input.PortableDataStream @@ -47,7 +47,7 @@ private[spark] class PythonRDD( pythonIncludes: JList[String], preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], + broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { @@ -230,8 +230,7 @@ private[spark] class PythonRDD( if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + PythonRDD.writeUTF(broadcast.value.path, dataOut) oldBids.add(broadcast.id) } } @@ -368,16 +367,8 @@ private[spark] object PythonRDD extends Logging { } } - def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { - val file = new DataInputStream(new FileInputStream(filename)) - try { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - sc.broadcast(obj) - } finally { - file.close() - } + def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = { + sc.broadcast(new PythonBroadcast(path)) } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { @@ -816,3 +807,49 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: } } } + +/** + * An Wrapper for Python Broadcast, which is written into disk by Python. 
It also will + * write the data into disk after deserialization, then Python can read it from disks. + */ +private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { + + /** + * Read data from disks, then copy it to `out` + */ + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + val in = new FileInputStream(new File(path)) + try { + Utils.copyStream(in, out) + } finally { + in.close() + } + } + + /** + * Write data into disk, using randomly generated name. + */ + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + val dir = new File(Utils.getLocalDir(SparkEnv.get.conf)) + val file = File.createTempFile("broadcast", "", dir) + path = file.getAbsolutePath + val out = new FileOutputStream(file) + try { + Utils.copyStream(in, out) + } finally { + out.close() + } + } + + /** + * Delete the file once the object is GCed. + */ + override def finalize() { + if (!path.isEmpty) { + val file = new File(path) + if (file.exists()) { + file.delete() + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 87f5cf944ed85..a5ea478f231d7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -39,7 +39,7 @@ import scala.reflect.ClassTag * * {{{ * scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) - * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) + * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) * * scala> broadcastVar.value * res0: Array[Int] = Array(1, 2, 3) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 7dade04273b08..31f0a462f84d8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -191,10 +191,12 @@ private[broadcast] object HttpBroadcast extends Logging { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) uc = newuri.toURL.openConnection() + uc.setConnectTimeout(httpReadTimeout) uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") uc = new URL(url).openConnection() + uc.setConnectTimeout(httpReadTimeout) } val in = { diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 4e802e02c4149..2e1e52906ceeb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -75,7 +75,8 @@ private[spark] class ClientArguments(args: Array[String]) { if (!ClientArguments.isValidJarUrl(_jarUrl)) { println(s"Jar url '${_jarUrl}' is not in valid format.") - println(s"Must be a jar file path in URL format (e.g. hdfs://XX.jar, file://XX.jar)") + println(s"Must be a jar file path in URL format " + + "(e.g. 
hdfs://host:port/XX.jar, file:///XX.jar)") printUsageAndExit(-1) } @@ -119,7 +120,7 @@ object ClientArguments { def isValidJarUrl(s: String): Boolean = { try { val uri = new URI(s) - uri.getScheme != null && uri.getAuthority != null && s.endsWith("jar") + uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar") } catch { case _: URISyntaxException => false } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index b9dd8557ee904..c46f84de8444a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -92,6 +92,8 @@ private[deploy] object DeployMessages { case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object ReregisterWithMaster // used when a worker attempts to reconnect to a master + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index af94b05ce3847..039c8719e2867 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -87,8 +87,8 @@ object PythonRunner { // Strip the URI scheme from the path formattedPath = new URI(formattedPath).getScheme match { - case Utils.windowsDrive(d) if windows => formattedPath case null => formattedPath + case Utils.windowsDrive(d) if windows => formattedPath case _ => new URI(formattedPath).getPath } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b43e68e40f791..0c7d247519447 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -142,6 +142,8 @@ object SparkSubmit { printErrorAndExit("Cluster deploy mode is currently not supported for python applications.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + case (_, CLUSTER) if isSqlShell(args.mainClass) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") case _ => } @@ -340,11 +342,16 @@ object SparkSubmit { e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { println(s"Failed to load main class $childMainClass.") - println("You need to build Spark with -Phive.") + println("You need to build Spark with -Phive and -Phive-thriftserver.") } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } + // SPARK-4170 + if (classOf[scala.App].isAssignableFrom(mainClass)) { + printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + } + val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) if (!Modifier.isStatic(mainMethod.getModifiers)) { throw new IllegalStateException("The main method in the given main class must be static") @@ -388,6 +395,13 @@ object SparkSubmit { primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL } + /** + * Return whether the given main class represents a sql shell. + */ + private[spark] def isSqlShell(mainClass: String): Boolean = { + mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" + } + /** * Return whether the given primary resource requires running python. 
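The SPARK-4170 warning above exists roughly because scala.App relies on DelayedInit: vals defined in the object body may not yet be initialized when closures capture them, which can surface as confusing NullPointerExceptions on executors. A minimal sketch of the recommended driver shape (object and app names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

// Preferred shape: an explicit main() rather than `object MyDriver extends App`.
object MyDriver {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("my-driver"))
    try {
      println(sc.parallelize(1 to 100).map(_ * 2).reduce(_ + _))
    } finally {
      sc.stop()
    }
  }
}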
*/ diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 2b894a796c8c6..d2687faad62b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -129,6 +129,16 @@ private[spark] object SparkSubmitDriverBootstrapper { val process = builder.start() + // If we kill an app while it's running, its sub-process should be killed too. + Runtime.getRuntime().addShutdownHook(new Thread() { + override def run() = { + if (process != null) { + process.destroy() + process.waitFor() + } + } + }) + // Redirect stdout and stderr from the child JVM val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout") val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr") @@ -139,14 +149,15 @@ private[spark] object SparkSubmitDriverBootstrapper { // subprocess there already reads directly from our stdin, so we should avoid spawning a // thread that contends with the subprocess in reading from System.in. val isWindows = Utils.isWindows - val isPySparkShell = sys.env.contains("PYSPARK_SHELL") + val isSubprocess = sys.env.contains("IS_SUBPROCESS") if (!isWindows) { val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin") stdinThread.start() - // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM - // should terminate on broken pipe, which signals that the parent process has exited. In - // Windows, the termination logic for the PySpark shell is handled in java_gateway.py - if (isPySparkShell) { + // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on + // broken pipe, signaling that the parent process has exited. This is the case if the + // application is launched directly from python, as in the PySpark shell. In Windows, + // the termination logic is handled in java_gateway.py + if (isSubprocess) { stdinThread.join() process.destroy() } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2d1609b973607..82a54dbfb5330 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -29,22 +29,27 @@ import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.Utils +/** + * A class that provides application history from event logs stored in the file system. + * This provider checks for new finished applications in the background periodically and + * renders the history application UI by parsing the associated event logs. 
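As the changes below show, the file-system history provider now falls back to a default log directory (file:/tmp/spark-events) and is pointed elsewhere via spark.history.fs.logDirectory. It only has something to display if applications actually wrote event logs; a minimal sketch of enabling that on the application side (the HDFS path and app name are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

object LoggedApp {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("logged-app")
      .set("spark.eventLog.enabled", "true")                    // write event logs for the history server
      .set("spark.eventLog.dir", "hdfs:///shared/spark-events") // should match spark.history.fs.logDirectory
    val sc = new SparkContext(conf)
    sc.parallelize(1 to 1000).count()
    // Stop the context so the application is recorded as completed and shows up in the history UI.
    sc.stop()
  }
}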
+ */ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider with Logging { + import FsHistoryProvider._ + private val NOT_STARTED = "" // Interval between each check for event log updates private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", conf.getInt("spark.history.updateInterval", 10)) * 1000 - private val logDir = conf.get("spark.history.fs.logDirectory", null) - private val resolvedLogDir = Option(logDir) - .map { d => Utils.resolveURI(d) } - .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") } + private val logDir = conf.getOption("spark.history.fs.logDirectory") + .map { d => Utils.resolveURI(d).toString } + .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(resolvedLogDir, - SparkHadoopUtil.get.newConfiguration(conf)) + private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L @@ -87,14 +92,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def initialize() { // Validate the log directory. - val path = new Path(resolvedLogDir) + val path = new Path(logDir) if (!fs.exists(path)) { - throw new IllegalArgumentException( - "Logging directory specified does not exist: %s".format(resolvedLogDir)) + var msg = s"Log directory specified does not exist: $logDir." + if (logDir == DEFAULT_LOG_DIR) { + msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + } + throw new IllegalArgumentException(msg) } if (!fs.getFileStatus(path).isDir) { throw new IllegalArgumentException( - "Logging directory specified is not a directory: %s".format(resolvedLogDir)) + "Logging directory specified is not a directory: %s".format(logDir)) } checkForLogs() @@ -134,8 +142,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } - override def getConfig(): Map[String, String] = - Map("Event Log Location" -> resolvedLogDir.toString) + override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString) /** * Builds the application list based on the current contents of the log directory. @@ -146,7 +153,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis lastLogCheckTimeMs = getMonotonicTimeMs() logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) try { - val logStatus = fs.listStatus(new Path(resolvedLogDir)) + val logStatus = fs.listStatus(new Path(logDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() // Load all new logs from the log directory. Only directories that have a modification time @@ -244,6 +251,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } +private object FsHistoryProvider { + val DEFAULT_LOG_DIR = "file:/tmp/spark-events" +} + private class FsApplicationHistoryInfo( val logDir: String, id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0e249e51a77d8..5fdc350cd8512 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -58,7 +58,13 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { ++ appTable } else { -

<h4>No Completed Applications Found</h4>
+            <h4>No completed applications found!</h4> ++
+            <p>Did you specify the correct logging directory?
+              Please verify your setting of
+              spark.history.fs.logDirectory and whether you have the permissions to
+              access it. It is also possible that your application did not run to
+              completion or did not stop the SparkContext.
+            </p>
} } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 5bce32a04d16d..b1270ade9f750 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -17,14 +17,13 @@ package org.apache.spark.deploy.history -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { - private var logDir: String = null +private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null parse(args.toList) @@ -32,7 +31,8 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] private def parse(args: List[String]): Unit = { args match { case ("--dir" | "-d") :: value :: tail => - logDir = value + logWarning("Setting log directory through the command line is deprecated as of " + + "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.") conf.set("spark.history.fs.logDirectory", value) System.setProperty("spark.history.fs.logDirectory", value) parse(tail) @@ -78,9 +78,10 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] | (default 50) |FsHistoryProvider options: | - | spark.history.fs.logDirectory Directory where app logs are stored (required) - | spark.history.fs.updateInterval How often to reload log data from storage (in seconds, - | default 10) + | spark.history.fs.logDirectory Directory where app logs are stored + | (default: file:/tmp/spark-events) + | spark.history.fs.updateInterval How often to reload log data from storage + | (in seconds, default: 10) |""".stripMargin) System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 6ba395be1cc2c..ad7d81747c377 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import akka.actor.ActorRef +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala index 2ac21186881fa..9d3d7938c6ccb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.master import java.util.Date +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.DriverDescription import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 08a99bbe68578..36a2e2c6a6349 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ 
-19,10 +19,13 @@ package org.apache.spark.deploy.master import java.io._ +import scala.reflect.ClassTag + import akka.serialization.Serialization import org.apache.spark.Logging + /** * Stores data in a single on-disk directory with one file per application and worker. * Files are deleted when applications and workers are removed. @@ -37,51 +40,24 @@ private[spark] class FileSystemPersistenceEngine( new File(dir).mkdir() - override def addApplication(app: ApplicationInfo) { - val appFile = new File(dir + File.separator + "app_" + app.id) - serializeIntoFile(appFile, app) - } - - override def removeApplication(app: ApplicationInfo) { - new File(dir + File.separator + "app_" + app.id).delete() - } - - override def addDriver(driver: DriverInfo) { - val driverFile = new File(dir + File.separator + "driver_" + driver.id) - serializeIntoFile(driverFile, driver) - } - - override def removeDriver(driver: DriverInfo) { - new File(dir + File.separator + "driver_" + driver.id).delete() - } - - override def addWorker(worker: WorkerInfo) { - val workerFile = new File(dir + File.separator + "worker_" + worker.id) - serializeIntoFile(workerFile, worker) + override def persist(name: String, obj: Object): Unit = { + serializeIntoFile(new File(dir + File.separator + name), obj) } - override def removeWorker(worker: WorkerInfo) { - new File(dir + File.separator + "worker_" + worker.id).delete() + override def unpersist(name: String): Unit = { + new File(dir + File.separator + name).delete() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - val sortedFiles = new File(dir).listFiles().sortBy(_.getName) - val appFiles = sortedFiles.filter(_.getName.startsWith("app_")) - val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) - val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_")) - val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) - val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_")) - val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) - (apps, drivers, workers) + override def read[T: ClassTag](prefix: String) = { + val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) + files.map(deserializeFromFile[T]) } private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) try { out.write(serialized) @@ -90,7 +66,7 @@ private[spark] class FileSystemPersistenceEngine( } } - def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { + private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) try { @@ -98,9 +74,9 @@ private[spark] class FileSystemPersistenceEngine( } finally { dis.close() } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) serializer.fromBinary(fileData).asInstanceOf[T] } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index 4433a2ec29be6..cf77c86d760cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -17,30 +17,27 @@ package org.apache.spark.deploy.master -import akka.actor.{Actor, ActorRef} - -import org.apache.spark.deploy.master.MasterMessages.ElectedLeader +import org.apache.spark.annotation.DeveloperApi /** - * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it - * is the only Master serving requests. - * In addition to the API provided, the LeaderElectionAgent will use of the following messages - * to inform the Master of leader changes: - * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]] - * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]] + * :: DeveloperApi :: + * + * A LeaderElectionAgent tracks current master and is a common interface for all election Agents. */ -private[spark] trait LeaderElectionAgent extends Actor { - // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring. - val masterActor: ActorRef +@DeveloperApi +trait LeaderElectionAgent { + val masterActor: LeaderElectable + def stop() {} // to avoid noops in implementations. } -/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ -private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent { - override def preStart() { - masterActor ! ElectedLeader - } +@DeveloperApi +trait LeaderElectable { + def electedLeader() + def revokedLeadership() +} - override def receive = { - case _ => - } +/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ +private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable) + extends LeaderElectionAgent { + masterActor.electedLeader() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2f81d472d7b78..7b32c505def9b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -30,6 +30,7 @@ import scala.util.Random import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -50,7 +51,7 @@ private[spark] class Master( port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { + extends Actor with ActorLogReceive with Logging with LeaderElectable { import context.dispatcher // to use Akka's scheduler.schedule() @@ -61,7 +62,6 @@ private[spark] class Master( val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) - val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") val workers = new HashSet[WorkerInfo] @@ -103,7 +103,7 @@ private[spark] class Master( var persistenceEngine: PersistenceEngine = _ - var leaderElectionAgent: ActorRef = _ + var leaderElectionAgent: LeaderElectionAgent = _ private var recoveryCompletionTask: Cancellable = _ @@ -130,23 +130,27 @@ private[spark] class Master( masterMetricsSystem.start() 
applicationMetricsSystem.start() - persistenceEngine = RECOVERY_MODE match { + val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") - new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf) + val zkFactory = + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => - logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) - new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system)) + val fsFactory = + new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) + case "CUSTOM" => + val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) + val factory = clazz.getConstructor(conf.getClass, Serialization.getClass) + .newInstance(conf, SerializationExtension(context.system)) + .asInstanceOf[StandaloneRecoveryModeFactory] + (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => - new BlackHolePersistenceEngine() + (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this)) } - - leaderElectionAgent = RECOVERY_MODE match { - case "ZOOKEEPER" => - context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf)) - case _ => - context.actorOf(Props(classOf[MonarchyLeaderAgent], self)) - } + persistenceEngine = persistenceEngine_ + leaderElectionAgent = leaderElectionAgent_ } override def preRestart(reason: Throwable, message: Option[Any]) { @@ -165,7 +169,15 @@ private[spark] class Master( masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() - context.stop(leaderElectionAgent) + leaderElectionAgent.stop() + } + + override def electedLeader() { + self ! ElectedLeader + } + + override def revokedLeadership() { + self ! RevokedLeadership } override def receiveWithLogging = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index e3640ea4f7e64..2e0e1e7036ac8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -17,6 +17,10 @@ package org.apache.spark.deploy.master +import org.apache.spark.annotation.DeveloperApi + +import scala.reflect.ClassTag + /** * Allows Master to persist any state that is necessary in order to recover from a failure. * The following semantics are required: @@ -25,36 +29,70 @@ package org.apache.spark.deploy.master * Given these two requirements, we will have all apps and workers persisted, but * we might not have yet deleted apps or workers that finished (so their liveness must be verified * during recovery). + * + * The implementation of this trait defines how name-object pairs are stored or retrieved. */ -private[spark] trait PersistenceEngine { - def addApplication(app: ApplicationInfo) +@DeveloperApi +trait PersistenceEngine { - def removeApplication(app: ApplicationInfo) + /** + * Defines how the object is serialized and persisted. Implementation will + * depend on the store used. + */ + def persist(name: String, obj: Object) - def addWorker(worker: WorkerInfo) + /** + * Defines how the object referred by its name is removed from the store. 
+ */ + def unpersist(name: String) - def removeWorker(worker: WorkerInfo) + /** + * Gives all objects, matching a prefix. This defines how objects are + * read/deserialized back. + */ + def read[T: ClassTag](prefix: String): Seq[T] - def addDriver(driver: DriverInfo) + final def addApplication(app: ApplicationInfo): Unit = { + persist("app_" + app.id, app) + } - def removeDriver(driver: DriverInfo) + final def removeApplication(app: ApplicationInfo): Unit = { + unpersist("app_" + app.id) + } + + final def addWorker(worker: WorkerInfo): Unit = { + persist("worker_" + worker.id, worker) + } + + final def removeWorker(worker: WorkerInfo): Unit = { + unpersist("worker_" + worker.id) + } + + final def addDriver(driver: DriverInfo): Unit = { + persist("driver_" + driver.id, driver) + } + + final def removeDriver(driver: DriverInfo): Unit = { + unpersist("driver_" + driver.id) + } /** * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) + final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } def close() {} } private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { - override def addApplication(app: ApplicationInfo) {} - override def removeApplication(app: ApplicationInfo) {} - override def addWorker(worker: WorkerInfo) {} - override def removeWorker(worker: WorkerInfo) {} - override def addDriver(driver: DriverInfo) {} - override def removeDriver(driver: DriverInfo) {} - - override def readPersistedData() = (Nil, Nil, Nil) + + override def persist(name: String, obj: Object): Unit = {} + + override def unpersist(name: String): Unit = {} + + override def read[T: ClassTag](name: String): Seq[T] = Nil + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala new file mode 100644 index 0000000000000..1096eb0368357 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import akka.serialization.Serialization + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.annotation.DeveloperApi + +/** + * ::DeveloperApi:: + * + * Implementation of this class can be plugged in as recovery mode alternative for Spark's + * Standalone mode. 
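With persist, unpersist, and read as the only abstract operations on PersistenceEngine, and the StandaloneRecoveryModeFactory introduced just below, a third-party store can be plugged in through spark.deploy.recoveryMode=CUSTOM and spark.deploy.recoveryMode.factory (the CUSTOM branch added to the Master above instantiates the factory reflectively). A deliberately simplified in-memory sketch, useful only for tests since it does not survive a Master restart; class names are illustrative:

import scala.collection.mutable
import scala.reflect.ClassTag

import akka.serialization.Serialization

import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.{LeaderElectable, LeaderElectionAgent,
  PersistenceEngine, StandaloneRecoveryModeFactory}

// Keeps recovery state in a plain map instead of serializing it to an external store.
class InMemoryPersistenceEngine extends PersistenceEngine {
  private val store = mutable.Map.empty[String, Object]

  override def persist(name: String, obj: Object): Unit = synchronized { store(name) = obj }

  override def unpersist(name: String): Unit = synchronized { store -= name }

  override def read[T: ClassTag](prefix: String): Seq[T] = synchronized {
    store.collect { case (k, v) if k.startsWith(prefix) => v.asInstanceOf[T] }.toSeq
  }
}

class InMemoryRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
  extends StandaloneRecoveryModeFactory(conf, serializer) {

  override def createPersistenceEngine(): PersistenceEngine = new InMemoryPersistenceEngine

  // Single-master setup: claim leadership immediately, like MonarchyLeaderAgent does.
  override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent =
    new LeaderElectionAgent {
      override val masterActor: LeaderElectable = master
      master.electedLeader()
    }
}

// Selected by the standalone Master at startup (see the CUSTOM branch above):
//   spark.deploy.recoveryMode          CUSTOM
//   spark.deploy.recoveryMode.factory  com.example.InMemoryRecoveryModeFactory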
+ * + */ +@DeveloperApi +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { + + /** + * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) + * is handled for recovery. + * + */ + def createPersistenceEngine(): PersistenceEngine + + /** + * Create an instance of LeaderAgent that decides who gets elected as master. + */ + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent +} + +/** + * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual + * recovery is made by restoring from filesystem. + */ +private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") + + def createPersistenceEngine() = { + logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) + new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) + } + + def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) +} + +private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) { + def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) + + def createLeaderElectionAgent(master: LeaderElectable) = + new ZooKeeperLeaderElectionAgent(master, conf) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index d221b0f6cc86b..473ddc23ff0f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import akka.actor.ActorRef +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils private[spark] class WorkerInfo( diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 285f9b014e291..8eaa0ad948519 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -24,9 +24,8 @@ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} -private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, - masterUrl: String, conf: SparkConf) - extends LeaderElectionAgent with LeaderLatchListener with Logging { +private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, + conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" @@ -34,30 +33,21 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, private var leaderLatch: LeaderLatch = _ private var status = LeadershipStatus.NOT_LEADER - override def preStart() { + start() + def start() { logInfo("Starting ZooKeeper LeaderElection agent") zk = SparkCuratorUtil.newClient(conf) leaderLatch = new LeaderLatch(zk, WORKING_DIR) leaderLatch.addListener(this) - leaderLatch.start() } - override def 
preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) { - logError("LeaderElectionAgent failed...", reason) - super.preRestart(reason, message) - } - - override def postStop() { + override def stop() { leaderLatch.close() zk.close() } - override def receive = { - case _ => - } - override def isLeader() { synchronized { // could have lost leadership by now. @@ -85,10 +75,10 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, def updateLeadershipStatus(isLeader: Boolean) { if (isLeader && status == LeadershipStatus.NOT_LEADER) { status = LeadershipStatus.LEADER - masterActor ! ElectedLeader + masterActor.electedLeader() } else if (!isLeader && status == LeadershipStatus.LEADER) { status = LeadershipStatus.NOT_LEADER - masterActor ! RevokedLeadership + masterActor.revokedLeadership() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 834dfedee52ce..e11ac031fb9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,15 +17,18 @@ package org.apache.spark.deploy.master +import akka.serialization.Serialization + import scala.collection.JavaConversions._ +import scala.reflect.ClassTag -import akka.serialization.Serialization import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} -class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) + +private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { @@ -34,52 +37,31 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) SparkCuratorUtil.mkdir(zk, WORKING_DIR) - override def addApplication(app: ApplicationInfo) { - serializeIntoFile(WORKING_DIR + "/app_" + app.id, app) - } - override def removeApplication(app: ApplicationInfo) { - zk.delete().forPath(WORKING_DIR + "/app_" + app.id) + override def persist(name: String, obj: Object): Unit = { + serializeIntoFile(WORKING_DIR + "/" + name, obj) } - override def addDriver(driver: DriverInfo) { - serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver) + override def unpersist(name: String): Unit = { + zk.delete().forPath(WORKING_DIR + "/" + name) } - override def removeDriver(driver: DriverInfo) { - zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id) - } - - override def addWorker(worker: WorkerInfo) { - serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker) - } - - override def removeWorker(worker: WorkerInfo) { - zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id) + override def read[T: ClassTag](prefix: String) = { + val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) + file.map(deserializeFromFile[T]).flatten } override def close() { zk.close() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted - val appFiles = sortedFiles.filter(_.startsWith("app_")) - val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten - val driverFiles = sortedFiles.filter(_.startsWith("driver_")) - val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten - val workerFiles = 
sortedFiles.filter(_.startsWith("worker_")) - val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten - (apps, drivers, workers) - } - private def serializeIntoFile(path: String, value: AnyRef) { val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) } - def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = { + def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala index 88118e2837741..b9798963bab0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -39,8 +39,8 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) - private val blockHandler = new ExternalShuffleBlockHandler() + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val blockHandler = new ExternalShuffleBlockHandler(transportConf) private val transportContext: TransportContext = { val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler new TransportContext(transportConf, handler) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ca262de832e25..eb11163538b20 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,7 +21,6 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} -import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap @@ -177,6 +176,9 @@ private[spark] class Worker( throw new SparkException("Invalid spark URL: " + x) } connected = true + // Cancel any outstanding re-registration attempts because we found a new master + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = None } private def tryRegisterAllMasters() { @@ -187,7 +189,12 @@ private[spark] class Worker( } } - private def retryConnectToMaster() { + /** + * Re-register with the master because a network failure or a master failure has occurred. + * If the re-registration attempt threshold is exceeded, the worker exits with error. + * Note that for thread-safety this should only be called from the actor. 
+ */ + private def reregisterWithMaster(): Unit = { Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { @@ -195,12 +202,40 @@ private[spark] class Worker( registrationRetryTimer = None } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") - tryRegisterAllMasters() + /** + * Re-register with the active master this worker has been communicating with. If there + * is none, then it means this worker is still bootstrapping and hasn't established a + * connection with a master yet, in which case we should re-register with all masters. + * + * It is important to re-register only with the active master during failures. Otherwise, + * if the worker unconditionally attempts to re-register with all masters, the following + * race condition may arise and cause a "duplicate worker" error detailed in SPARK-4592: + * + * (1) Master A fails and Worker attempts to reconnect to all masters + * (2) Master B takes over and notifies Worker + * (3) Worker responds by registering with Master B + * (4) Meanwhile, Worker's previous reconnection attempt reaches Master B, + * causing the same Worker to register with Master B twice + * + * Instead, if we only register with the known active master, we can assume that the + * old master must have died because another master has taken over. Note that this is + * still not safe if the old master recovers within this interval, but this is a much + * less likely scenario. + */ + if (master != null) { + master ! RegisterWorker( + workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + } else { + // We are retrying the initial registration + tryRegisterAllMasters() + } + // We have exceeded the initial registration retry threshold + // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { registrationRetryTimer.foreach(_.cancel()) registrationRetryTimer = Some { context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) + PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) } } } else { @@ -220,7 +255,7 @@ private[spark] class Worker( connectionAttemptCount = 0 registrationRetryTimer = Some { context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) + INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) } case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + @@ -400,12 +435,15 @@ private[spark] class Worker( logInfo(s"$x Disassociated !") masterDisconnected() - case RequestWorkerState => { + case RequestWorkerState => sender ! 
WorkerStateResponse(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, drivers.values.toList, finishedDrivers.values.toList, activeMasterUrl, cores, memory, coresUsed, memoryUsed, activeMasterWebUiUrl) - } + + case ReregisterWithMaster => + reregisterWithMaster() + } private def masterDisconnected() { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 3711824a40cfc..5f46f3b1f085e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -57,9 +57,9 @@ private[spark] class CoarseGrainedExecutorBackend( override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") - // Make this host instead of hostPort ? val (hostname, _) = Utils.parseHostPort(hostPort) - executor = new Executor(executorId, hostname, sparkProperties, isLocal = false, actorSystem) + executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false, + actorSystem) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index caf4d76713d49..835157fc520aa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -43,6 +43,7 @@ private[spark] class Executor( executorId: String, slaveHostname: String, properties: Seq[(String, String)], + numCores: Int, isLocal: Boolean = false, actorSystem: ActorSystem = null) extends Logging @@ -83,7 +84,7 @@ private[spark] class Executor( if (!isLocal) { val port = conf.getInt("spark.executor.port", 0) val _env = SparkEnv.createExecutorEnv( - conf, executorId, slaveHostname, port, isLocal, actorSystem) + conf, executorId, slaveHostname, port, numCores, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) _env.blockManager.initialize(conf.getAppId) @@ -220,7 +221,7 @@ private[spark] class Executor( // directSend = sending directly back to the driver val serializedResult = { - if (resultSize > maxResultSize) { + if (maxResultSize > 0 && resultSize > maxResultSize) { logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + s"dropping it.") @@ -333,7 +334,7 @@ private[spark] class Executor( * SparkContext. Also adds any new JARs we fetched to the class loader. 
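The revised guard above treats a non-positive maxResultSize as "no limit", so only a configured positive cap causes oversized task results to be dropped. Assuming the standard spark.driver.maxResultSize key feeds this check, raising or disabling the cap might look like this (fragment-level sketch):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("big-results")
  .set("spark.driver.maxResultSize", "2g") // raise the cap on serialized task results sent to the driver
// or disable the check entirely, accepting the risk of driver OOM:
// conf.set("spark.driver.maxResultSize", "0")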
*/ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) synchronized { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index bca0b152268ad..f15e6bc33fb41 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.executor import java.nio.ByteBuffer +import scala.collection.JavaConversions._ + import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} @@ -50,14 +52,23 @@ private[spark] class MesosExecutorBackend executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { - logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) + + // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. + val cpusPerTask = executorInfo.getResourcesList + .find(_.getName == "cpus") + .map(_.getScalar.getValue.toInt) + .getOrElse(0) + val executorId = executorInfo.getExecutorId.getValue + + logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus") this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) executor = new Executor( - executorInfo.getExecutorId.getValue, + executorId, slaveInfo.getHostname, - properties) + properties, + cpusPerTask) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 183bce3d8d8d3..d3601cca832b2 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -19,14 +19,13 @@ package org.apache.spark.input import scala.collection.JavaConversions._ +import org.apache.hadoop.conf.{Configuration, Configurable} import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader -import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit /** * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for @@ -34,17 +33,24 @@ import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit * the value is the entire content of file. 
*/ -private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { +private[spark] class WholeTextFileInputFormat + extends CombineFileInputFormat[String, String] with Configurable { + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf + override def createRecordReader( split: InputSplit, context: TaskAttemptContext): RecordReader[String, String] = { - new CombineFileRecordReader[String, String]( - split.asInstanceOf[CombineFileSplit], - context, - classOf[WholeTextFileRecordReader]) + val reader = new WholeCombineFileRecordReader(split, context) + reader.setConf(conf) + reader } /** diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 3564ab2e2a162..6d59b24eb0596 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -17,11 +17,13 @@ package org.apache.spark.input +import org.apache.hadoop.conf.{Configuration, Configurable} import com.google.common.io.{ByteStreams, Closeables} import org.apache.hadoop.io.Text +import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -34,7 +36,13 @@ private[spark] class WholeTextFileRecordReader( split: CombineFileSplit, context: TaskAttemptContext, index: Integer) - extends RecordReader[String, String] { + extends RecordReader[String, String] with Configurable { + + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf private[this] val path = split.getPath(index) private[this] val fs = path.getFileSystem(context.getConfiguration) @@ -57,8 +65,16 @@ private[spark] class WholeTextFileRecordReader( override def nextKeyValue(): Boolean = { if (!processed) { + val conf = new Configuration + val factory = new CompressionCodecFactory(conf) + val codec = factory.getCodec(path) // infers from file ext. val fileIn = fs.open(path) - val innerBuffer = ByteStreams.toByteArray(fileIn) + val innerBuffer = if (codec != null) { + ByteStreams.toByteArray(codec.createInputStream(fileIn)) + } else { + ByteStreams.toByteArray(fileIn) + } + value = new Text(innerBuffer).toString Closeables.close(fileIn, false) processed = true @@ -68,3 +84,33 @@ private[spark] class WholeTextFileRecordReader( } } } + + +/** + * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file + * out in a key-value pair, where the key is the file path and the value is the entire content of + * the file. 
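// --- Usage sketch (not part of the patch) ---
// With the CompressionCodecFactory lookup added to WholeTextFileRecordReader above,
// sc.wholeTextFiles can read compressed whole-file inputs; the codec is inferred from the
// file extension. The input path below is hypothetical; assumes a local-mode context.
import org.apache.spark.{SparkConf, SparkContext}

object WholeTextFilesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("whole-text-files").setMaster("local[2]"))
    // Each record is (filePath, entireFileContents); .gz files are decompressed transparently.
    val sizes = sc.wholeTextFiles("/data/logs/*.gz").map { case (path, content) => (path, content.length) }
    sizes.take(5).foreach { case (path, n) => println(s"$path -> $n chars") }
    sc.stop()
  }
}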
+ */ +private[spark] class WholeCombineFileRecordReader( + split: InputSplit, + context: TaskAttemptContext) + extends CombineFileRecordReader[String, String]( + split.asInstanceOf[CombineFileSplit], + context, + classOf[WholeTextFileRecordReader] + ) with Configurable { + + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf + + override def initNextRecordReader(): Boolean = { + val r = super.initNextRecordReader() + if (r) { + this.curReader.asInstanceOf[WholeTextFileRecordReader].setConf(conf) + } + r + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index f8a7f640689a2..0027cbb0ff1fb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -35,13 +35,13 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager) +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int) extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 9fa4fa77b8817..cef203006d685 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -21,12 +21,53 @@ import org.apache.spark.SparkConf import org.apache.spark.network.util.{TransportConf, ConfigProvider} /** - * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor, + * Driver, or a standalone shuffle service) into a TransportConf with details on our environment + * like the number of cores that are allocated to this JVM. */ object SparkTransportConf { - def fromSparkConf(conf: SparkConf): TransportConf = { + /** + * Specifies an upper bound on the number of Netty threads that Spark requires by default. + * In practice, only 2-4 cores should be required to transfer roughly 10 Gb/s, and each core + * that we use will have an initial overhead of roughly 32 MB of off-heap memory, which comes + * at a premium. + * + * Thus, this value should still retain maximum throughput and reduce wasted off-heap memory + * allocation. It can be overridden by setting the number of serverThreads and clientThreads + * manually in Spark's configuration. + */ + private val MAX_DEFAULT_NETTY_THREADS = 8 + + /** + * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param numUsableCores if nonzero, this will restrict the server and client threads to only + * use the given number of cores, rather than all of the machine's cores. 
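// --- Configuration sketch (not part of the patch) ---
// As the comment above notes, the capped Netty thread default can be overridden by setting the
// server/client thread counts explicitly; these are the keys consulted by fromSparkConf below.
import org.apache.spark.SparkConf

object TransportThreadsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.shuffle.io.serverThreads", "4") // threads for the transport server
      .set("spark.shuffle.io.clientThreads", "4") // threads for the transport client
    println(conf.get("spark.shuffle.io.serverThreads"))
  }
}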
+ * This restriction will only occur if these properties are not already set. + */ + def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + val conf = _conf.clone + + // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily + // assuming we have all the machine's cores). + // NB: Only set if serverThreads/clientThreads not already set. + val numThreads = defaultNumThreads(numUsableCores) + conf.set("spark.shuffle.io.serverThreads", + conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) + conf.set("spark.shuffle.io.clientThreads", + conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + new TransportConf(new ConfigProvider { override def get(name: String): String = conf.get(name) }) } + + /** + * Returns the default number of threads for both the Netty client and server thread pools. + * If numUsableCores is 0, we will use Runtime get an approximate number of available cores. + */ + private def defaultNumThreads(numUsableCores: Int): Int = { + val availableCores = + if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() + math.min(availableCores, MAX_DEFAULT_NETTY_THREADS) + } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index f198aa8564a54..df4b085d2251e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -18,13 +18,13 @@ package org.apache.spark.network.nio import java.io.IOException +import java.lang.ref.WeakReference import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} -import java.util.{Timer, TimerTask} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ @@ -32,6 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} @@ -77,7 +78,8 @@ private[nio] class ConnectionManager( } private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) + private val ackTimeoutMonitor = + new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) @@ -139,7 +141,10 @@ private[nio] class ConnectionManager( new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] + // Tracks sent messages for which we are awaiting acknowledgements. 
Entries are added to this + // map when messages are sent and are removed when acknowledgement messages are received or when + // acknowledgement timeouts expire + private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] @@ -899,22 +904,41 @@ private[nio] class ConnectionManager( : Future[Message] = { val promise = Promise[Message]() - val timeoutTask = new TimerTask { - override def run(): Unit = { + // It's important that the TimerTask doesn't capture a reference to `message`, which can cause + // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time + // at which they would originally be scheduled to run. Therefore, extract the message id + // from outside of the TimerTask closure (see SPARK-4393 for more context). + val messageId = message.id + // Keep a weak reference to the promise so that the completed promise may be garbage-collected + val promiseReference = new WeakReference(promise) + val timeoutTask: TimerTask = new TimerTask { + override def run(timeout: Timeout): Unit = { messageStatuses.synchronized { - messageStatuses.remove(message.id).foreach ( s => { + messageStatuses.remove(messageId).foreach { s => val e = new IOException("sendMessageReliably failed because ack " + s"was not received within $ackTimeout sec") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) + val p = promiseReference.get + if (p != null) { + // Attempt to fail the promise with a Timeout exception + if (!p.tryFailure(e)) { + // If we reach here, then someone else has already signalled success or failure + // on this promise, so log a warning: + logError("Ignore error because promise is completed", e) + } + } else { + // The WeakReference was empty, which should never happen because + // sendMessageReliably's caller should have a strong reference to promise.future; + logError("Promise was garbage collected; this should never happen!", e) } - }) + } } } } + val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) + val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTask.cancel() + timeoutTaskHandle.cancel() s match { case scala.util.Failure(e) => // Indicates a failure where we either never sent or never got ACK'd @@ -943,7 +967,6 @@ private[nio] class ConnectionManager( messageStatuses += ((message.id, status)) } - ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -953,7 +976,7 @@ private[nio] class ConnectionManager( } def stop() { - ackTimeoutMonitor.cancel() + ackTimeoutMonitor.stop() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index e2fc9c649925e..436dbed1730bc 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -44,5 +44,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.2.0-SNAPSHOT" + val SPARK_VERSION = "1.3.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0e38f224ac81d..642a12c1edf6c 100644 --- 
a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet} import scala.reflect.ClassTag -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.util.NextIterator +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { override def index = idx @@ -125,5 +128,82 @@ object JdbcRDD { def resultSetToObjectArray(rs: ResultSet): Array[Object] = { Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) } -} + trait ConnectionFactory extends Serializable { + @throws[Exception] + def getConnection: Connection + } + + /** + * Create an RDD that executes an SQL query on a JDBC connection and reads results. + * For usage example, see test case JavaAPISuite.testJavaJdbcRDD. + * + * @param connectionFactory a factory that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. + */ + def create[T]( + sc: JavaSparkContext, + connectionFactory: ConnectionFactory, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: JFunction[ResultSet, T]): JavaRDD[T] = { + + val jdbcRDD = new JdbcRDD[T]( + sc.sc, + () => connectionFactory.getConnection, + sql, + lowerBound, + upperBound, + numPartitions, + (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag) + + new JavaRDD[T](jdbcRDD)(fakeClassTag) + } + + /** + * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is + * converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD. + * + * @param connectionFactory a factory that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. 
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + */ + def create( + sc: JavaSparkContext, + connectionFactory: ConnectionFactory, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): JavaRDD[Array[Object]] = { + + val mapRow = new JFunction[ResultSet, Array[Object]] { + override def call(resultSet: ResultSet): Array[Object] = { + resultSetToObjectArray(resultSet) + } + } + + create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 56ac7a69be0d3..ed79032893d33 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -63,7 +63,7 @@ private[spark] class PipedRDD[T: ClassTag]( /** * A FilenameFilter that accepts anything that isn't equal to the name passed in. - * @param name of file or directory to leave out + * @param filterName of file or directory to leave out */ class NotEqualsFileNameFilter(filterName: String) extends FilenameFilter { def accept(dir: File, name: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 716f2dd17733b..3add4a76192ca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -21,6 +21,7 @@ import java.util.{Properties, Random} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer +import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus @@ -28,6 +29,7 @@ import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -1202,7 +1204,7 @@ abstract class RDD[T: ClassTag]( */ def checkpoint() { if (context.checkpointDir.isEmpty) { - throw new Exception("Checkpoint directory has not been set in the SparkContext") + throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() @@ -1309,7 +1311,7 @@ abstract class RDD[T: ClassTag]( def debugSelf (rdd: RDD[_]): Seq[String] = { import Utils.bytesToString - val persistence = storageLevel.description + val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), @@ -1383,3 +1385,31 @@ abstract class RDD[T: ClassTag]( new JavaRDD(this)(elementClassTag) } } + +object RDD { + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. 
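// --- Usage sketch (not part of the patch) ---
// The implicit conversions defined just below in object RDD let pair-RDD operations such as
// reduceByKey resolve without `import org.apache.spark.SparkContext._`. Assumes a local-mode
// context; names are illustrative.
import org.apache.spark.{SparkConf, SparkContext}

object RddImplicitsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("rdd-implicits").setMaster("local[2]"))
    val counts = sc.parallelize(Seq("a", "b", "a"))
      .map(word => (word, 1))
      .reduceByKey(_ + _) // resolved via RDD.rddToPairRDDFunctions, no extra import required
    counts.collect().foreach(println)
    sc.stop()
  }
}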
+ + implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + new PairRDDFunctions(rdd) + } + + implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + rdd: RDD[(K, V)]) = + new SequenceFileRDDFunctions(rdd) + + implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + rdd: RDD[(K, V)]) = + new OrderedRDDFunctions[K, V, (K, V)](rdd) + + implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + + implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index e2c301603b4a5..8c43a559409f2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) private[spark] class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { - override def getPartitions: Array[Partition] = { + /** The start index of each partition. */ + @transient private val startIndices: Array[Long] = { val n = prev.partitions.size - val startIndices: Array[Long] = - if (n == 0) { - Array[Long]() - } else if (n == 1) { - Array(0L) - } else { - prev.context.runJob( - prev, - Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - false - ).scanLeft(0L)(_ + _) - } + if (n == 0) { + Array[Long]() + } else if (n == 1) { + Array(0L) + } else { + prev.context.runJob( + prev, + Utils.getIteratorSize _, + 0 until n - 1, // do not need to count the last partition + allowLocal = false + ).scanLeft(0L)(_ + _) + } + } + + override def getPartitions: Array[Partition] = { firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index))) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 22449517d100f..cb8ccfbdbdcbb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -449,7 +449,6 @@ class DAGScheduler( } // data structures based on StageId stageIdToStage -= stageId - logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } @@ -751,14 +750,15 @@ class DAGScheduler( localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. 
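// --- Worked sketch (not part of the patch) ---
// The startIndices computation in ZippedWithIndexRDD above counts every partition except the
// last and scans the counts to get each partition's global starting index. Counts below are
// hypothetical.
object ZipWithIndexOffsetsSketch {
  def main(args: Array[String]): Unit = {
    val countsOfLeadingPartitions = Array(3L, 5L, 2L) // sizes of partitions 0, 1, 2 of a 4-partition RDD
    val startIndices = countsOfLeadingPartitions.scanLeft(0L)(_ + _)
    println(startIndices.mkString(", ")) // 0, 3, 8, 10: where each partition's indices begin
  }
}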
- listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) + listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties)) runLocally(job) } else { jobIdToActiveJob(jobId) = job activeJobs += job finalStage.resultOfJob = Some(job) - listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray, - properties)) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties)) submitStage(finalStage) } } @@ -901,6 +901,34 @@ class DAGScheduler( } } + /** Merge updates from a task to our local accumulator values */ + private def updateAccumulators(event: CompletionEvent): Unit = { + val task = event.task + val stage = stageIdToStage(task.stageId) + if (event.accumUpdates != null) { + try { + Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => + val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name.get + val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) + val stringValue = Accumulators.stringifyValue(acc.value) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + event.taskInfo.accumulables += + AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + } + } + } catch { + // If we see an exception during accumulator update, just log the + // error and move on. + case e: Exception => + logError(s"Failed to update accumulators for $task", e) + } + } + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. @@ -941,27 +969,6 @@ class DAGScheduler( } event.reason match { case Success => - if (event.accumUpdates != null) { - try { - Accumulators.add(event.accumUpdates) - event.accumUpdates.foreach { case (id, partialValue) => - val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] - // To avoid UI cruft, ignore cases where value wasn't updated - if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) - event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) - } - } - } catch { - // If we see an exception during accumulator update, just log the error and move on. 
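// --- Usage sketch (not part of the patch) ---
// updateAccumulators above only publishes AccumulableInfo for accumulators created with a
// name, so named accumulators are what show up in the stage UI. The accumulator name and job
// below are illustrative; assumes a local-mode context.
import org.apache.spark.{SparkConf, SparkContext}

object NamedAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("named-accumulator").setMaster("local[2]"))
    val parsedRecords = sc.accumulator(0, "parsedRecords") // named, so it is reported to the UI
    sc.parallelize(1 to 1000).foreach(_ => parsedRecords += 1)
    println(parsedRecords.value) // 1000
    sc.stop()
  }
}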
- case e: Exception => - logError(s"Failed to update accumulators for $task", e) - } - } listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) stage.pendingTasks -= task @@ -970,6 +977,7 @@ class DAGScheduler( stage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { + updateAccumulators(event) job.finished(rt.outputId) = true job.numFinished += 1 // If the whole job has finished, remove it @@ -994,6 +1002,7 @@ class DAGScheduler( } case smt: ShuffleMapTask => + updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) @@ -1082,7 +1091,6 @@ class DAGScheduler( } failedStages += failedStage failedStages += mapStage - // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { mapStage.removeOutputLoc(mapId, bmAddress) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 86afe3bd5265f..b62b0c1312693 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -56,8 +56,15 @@ case class SparkListenerTaskEnd( extends SparkListenerEvent @DeveloperApi -case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null) - extends SparkListenerEvent +case class SparkListenerJobStart( + jobId: Int, + stageInfos: Seq[StageInfo], + properties: Properties = null) + extends SparkListenerEvent { + // Note: this is here for backwards-compatibility with older versions of this event which + // only stored stageIds and not StageInfos: + val stageIds: Seq[Int] = stageInfos.map(_.stageId) +} @DeveloperApi case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d8fb640350343..cabdc655f89bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -536,7 +536,7 @@ private[spark] class TaskSetManager( calculatedTasks += 1 if (maxResultSize > 0 && totalResultSize > maxResultSize) { val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + - s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than spark.driver.maxResultSize " + s"(${Utils.bytesToString(maxResultSize)})" logError(msg) abort(msg) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7a6ee56f81689..88b196ac64368 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -46,6 +46,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) + // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = 
scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) @@ -126,7 +127,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste makeOffers() case KillTask(taskId, executorId, interruptThread) => - executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread) + executorDataMap.get(executorId) match { + case Some(executorInfo) => + executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + case None => + // Ignoring the task kill since the executor is not registered. + logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") + } case StopDriver => sender ! true @@ -204,6 +211,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste executorsPendingToRemove -= executorId } totalCoreCount.addAndGet(-executorInfo.totalCores) + totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) case None => logError(s"Asked to remove non-existent executor $executorId") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c5f3493477bc5..10e6886c16a4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -166,29 +166,16 @@ private[spark] class MesosSchedulerBackend( execArgs } - private def setClassLoader(): ClassLoader = { - val oldClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classLoader) - oldClassLoader - } - - private def restoreClassLoader(oldClassLoader: ClassLoader) { - Thread.currentThread.setContextClassLoader(oldClassLoader) - } - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() } - } finally { - restoreClassLoader(oldClassLoader) } } @@ -200,6 +187,16 @@ private[spark] class MesosSchedulerBackend( } } + private def inClassLoader()(fun: => Unit) = { + val oldClassLoader = Thread.currentThread.getContextClassLoader + Thread.currentThread.setContextClassLoader(classLoader) + try { + fun + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) + } + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -210,66 +207,70 @@ private[spark] class MesosSchedulerBackend( * tasks are balanced across the cluster. 
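// --- Standalone sketch (not part of the patch) ---
// The inClassLoader() helper introduced above is a loan pattern: swap the thread context
// classloader for the duration of a callback and always restore it afterwards. A generic
// version with hypothetical names:
object ContextClassLoaderSketch {
  def withContextClassLoader[T](cl: ClassLoader)(body: => T): T = {
    val previous = Thread.currentThread.getContextClassLoader
    Thread.currentThread.setContextClassLoader(cl)
    try {
      body
    } finally {
      Thread.currentThread.setContextClassLoader(previous)
    }
  }

  def main(args: Array[String]): Unit = {
    val answer = withContextClassLoader(getClass.getClassLoader) {
      // scheduler callbacks would run here with the desired classloader in place
      42
    }
    println(answer)
  }
}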
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - val oldClassLoader = setClassLoader() - try { - synchronized { - // Build a big list of the offerable workers, and remember their indices so that we can - // figure out which Offer to reply to for each worker - val offerableWorkers = new ArrayBuffer[WorkerOffer] - val offerableIndices = new HashMap[String, Int] - - def sufficientOffer(o: Offer) = { - val mem = getResource(o.getResourcesList, "mem") - val cpus = getResource(o.getResourcesList, "cpus") - val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) - } + inClassLoader() { + // Fail-fast on offers we know will be rejected + val (usableOffers, unUsableOffers) = offers.partition { o => + val mem = getResource(o.getResourcesList, "mem") + val cpus = getResource(o.getResourcesList, "cpus") + val slaveId = o.getSlaveId.getValue + // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK? + (mem >= MemoryUtils.calculateTotalMemory(sc) && + // need at least 1 for executor, 1 for task + cpus >= 2 * scheduler.CPUS_PER_TASK) || + (slaveIdsWithExecutors.contains(slaveId) && + cpus >= scheduler.CPUS_PER_TASK) + } - for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) { - val slaveId = offer.getSlaveId.getValue - offerableIndices.put(slaveId, index) - val cpus = if (slaveIdsWithExecutors.contains(slaveId)) { - getResource(offer.getResourcesList, "cpus").toInt - } else { - // If the executor doesn't exist yet, subtract CPU for executor - getResource(offer.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK - } - offerableWorkers += new WorkerOffer( - offer.getSlaveId.getValue, - offer.getHostname, - cpus) + val workerOffers = usableOffers.map { o => + val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + getResource(o.getResourcesList, "cpus").toInt + } else { + // If the executor doesn't exist yet, subtract CPU for executor + // TODO(pwendell): Should below just subtract "1"? 
+ getResource(o.getResourcesList, "cpus").toInt - + scheduler.CPUS_PER_TASK } + new WorkerOffer( + o.getSlaveId.getValue, + o.getHostname, + cpus) + } + + val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap + + val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] - // Call into the TaskSchedulerImpl - val taskLists = scheduler.resourceOffers(offerableWorkers) - - // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]()) - for ((taskList, index) <- taskLists.zipWithIndex) { - if (!taskList.isEmpty) { - for (taskDesc <- taskList) { - val slaveId = taskDesc.executorId - val offerNum = offerableIndices(slaveId) - slaveIdsWithExecutors += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) - } + val slavesIdsOfAcceptedOffers = HashSet[String]() + + // Call into the TaskSchedulerImpl + val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) + acceptedOffers + .foreach { offer => + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } } - // Reply to the offers - val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? - for (i <- 0 until offers.size) { - d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters) - } + // Reply to the offers + val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? + + mesosTasks.foreach { case (slaveId, tasks) => + d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } - } finally { - restoreClassLoader(oldClassLoader) + + // Decline offers that weren't used + // NOTE: This logic assumes that we only get a single offer for each host in a given batch + for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { + d.declineOffer(o.getId) + } + + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } @@ -308,8 +309,7 @@ private[spark] class MesosSchedulerBackend( } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { @@ -322,18 +322,13 @@ private[spark] class MesosSchedulerBackend( } } scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) - } finally { - restoreClassLoader(oldClassLoader) } } override def error(d: SchedulerDriver, message: String) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logError("Mesos error: " + message) scheduler.error(message) - } finally { - restoreClassLoader(oldClassLoader) } } @@ -350,15 +345,12 @@ private[spark] class MesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) synchronized { slaveIdsWithExecutors -= slaveId.getValue } scheduler.executorLost(slaveId.getValue, reason) - } finally { - restoreClassLoader(oldClassLoader) } } diff 
--git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index c0264836de738..a2f1f14264a99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -51,7 +51,7 @@ private[spark] class LocalActor( private val localExecutorHostname = "localhost" val executor = new Executor( - localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) + localExecutorId, localExecutorHostname, scheduler.conf.getAll, totalCores, isLocal = true) override def receiveWithLogging = { case ReviveOffers => diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index f03e8e4bf1b7e..7de2f9cbb2866 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ @@ -68,6 +69,8 @@ private[spark] class FileShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager with Logging { + private val transportConf = SparkTransportConf.fromSparkConf(conf) + private lazy val blockManager = SparkEnv.get.blockManager // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. @@ -182,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf) val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) if (segmentOpt.isDefined) { val segment = segmentOpt.get - return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length) + return new FileSegmentManagedBuffer( + transportConf, segment.file, segment.offset, segment.length) } } throw new IllegalStateException("Failed to find shuffle block: " + blockId) } else { val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(file, 0, file.length) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index a48f0c9eceb5e..b292587d37028 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -22,8 +22,9 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ /** @@ -38,10 +39,12 @@ import org.apache.spark.storage._ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). 
private[spark] -class IndexShuffleBlockManager extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { private lazy val blockManager = SparkEnv.get.blockManager + private val transportConf = SparkTransportConf.fromSparkConf(conf) + /** * Mapping to a single shuffleBlockId with reduce ID 0. * */ @@ -109,6 +112,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { val offset = in.readLong() val nextOffset = in.readLong() new FileSegmentManagedBuffer( + transportConf, getDataFile(blockId.shuffleId, blockId.mapId), offset, nextOffset - offset) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b727438ae7e47..bda30a56d808e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { - private val indexShuffleBlockManager = new IndexShuffleBlockManager() + private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 39434f473a9d8..308c59eda594d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -73,7 +73,8 @@ private[spark] class BlockManager( mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, - securityManager: SecurityManager) + securityManager: SecurityManager, + numUsableCores: Int) extends BlockDataManager with Logging { val diskBlockManager = new DiskBlockManager(this, conf) @@ -121,8 +122,8 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. 
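// --- Configuration sketch (not part of the patch) ---
// The shuffleClient selection described above picks ExternalShuffleClient only when the
// external shuffle service is enabled. The key below is the standard switch for that in this
// era of Spark, but it does not appear in this diff, so treat it as an assumption.
import org.apache.spark.SparkConf

object ExternalShuffleServiceSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.shuffle.service.enabled", "true") // assumed key: route shuffle fetches through the service
    println(conf.getBoolean("spark.shuffle.service.enabled", false))
  }
}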
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager, - securityManager.isAuthenticationEnabled()) + val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) } else { blockTransferService } @@ -174,9 +175,10 @@ private[spark] class BlockManager( mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, - securityManager: SecurityManager) = { + securityManager: SecurityManager, + numUsableCores: Int) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 6b1f57a069431..83170f7c5a4ab 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -265,7 +265,7 @@ final class ShuffleBlockFetcherIterator( // Get Local Blocks fetchLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime)) } override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 6908a59a79e60..af873034215a9 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -148,6 +148,7 @@ private[spark] class TachyonBlockManager( logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } + client.close() } }) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala index 6dbad5ff0518e..233d1e2b7c616 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -116,6 +116,8 @@ private[spark] class TachyonStore( case ioe: IOException => logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) None + } finally { + is.close() } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala new file mode 100644 index 0000000000000..27ba9e18237b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import java.util.{Timer, TimerTask} + +import org.apache.spark._ + +/** + * ConsoleProgressBar shows the progress of stages in the next line of the console. It polls the + * status of active stages from `sc.statusTracker` periodically; the progress bar is shown once a + * stage has run for at least 500ms. If multiple stages run at the same time, their statuses are + * combined and shown on one line. + */ +private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { + + // Carriage return + val CR = '\r' + // Update period of progress bar, in milliseconds + val UPDATE_PERIOD = 200L + // Delay before showing the progress bar, in milliseconds + val FIRST_DELAY = 500L + + // The width of the terminal + val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) { + sys.env.get("COLUMNS").get.toInt + } else { + 80 + } + + var lastFinishTime = 0L + var lastUpdateTime = 0L + var lastProgressBar = "" + + // Schedule a refresh thread to run periodically + private val timer = new Timer("refresh progress", true) + timer.schedule(new TimerTask{ + override def run() { + refresh() + } + }, FIRST_DELAY, UPDATE_PERIOD) + + /** + * Try to refresh the progress bar on every cycle + */ + private def refresh(): Unit = synchronized { + val now = System.currentTimeMillis() + if (now - lastFinishTime < FIRST_DELAY) { + return + } + val stageIds = sc.statusTracker.getActiveStageIds() + val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) + .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) + if (stages.size > 0) { + show(now, stages.take(3)) // display at most 3 stages at the same time + } + } + + /** + * Show the progress bar in the console. The progress bar is displayed on the line after your + * last output and keeps overwriting itself to stay on one line. Log output follows the progress + * bar, and the bar is then redrawn on the next line without overwriting the logs. + */ + private def show(now: Long, stages: Seq[SparkStageInfo]) { + val width = TerminalWidth / stages.size + val bar = stages.map { s => + val total = s.numTasks() + val header = s"[Stage ${s.stageId()}:" + val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" + val w = width - header.size - tailer.size + val bar = if (w > 0) { + val percent = w * s.numCompletedTasks() / total + (0 until w).map { i => + if (i < percent) "=" else if (i == percent) ">" else " " + }.mkString("") + } else { + "" + } + header + bar + tailer + }.mkString("") + + // only refresh if the bar has changed, or after 1 minute (otherwise an idle ssh connection + // may be closed) + if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) { + System.err.print(CR + bar) + lastUpdateTime = now + } + lastProgressBar = bar + } + + /** + * Clear the progress bar if it is shown.
+ */ + private def clear() { + if (!lastProgressBar.isEmpty) { + System.err.printf(CR + " " * TerminalWidth + CR) + lastProgressBar = "" + } + } + + /** + * Mark all the stages as finished, clear the progress bar if showed, then the progress will not + * interweave with output of jobs. + */ + def finishAll(): Unit = synchronized { + clear() + lastFinishTime = System.currentTimeMillis() + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 049938f827291..176907dffa46a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -23,7 +23,7 @@ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} -import org.apache.spark.ui.jobs.{JobProgressListener, JobProgressTab} +import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} import org.apache.spark.ui.storage.{StorageListener, StorageTab} /** @@ -43,17 +43,20 @@ private[spark] class SparkUI private ( extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging { + val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + /** Initialize all components of the server. */ def initialize() { - val jobProgressTab = new JobProgressTab(this) - attachTab(jobProgressTab) + attachTab(new JobsTab(this)) + val stagesTab = new StagesTab(this) + attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/stages", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest)) + createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) // If the UI is live, then serve sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 3312671b6f885..315327c3c6b7c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -26,7 +26,8 @@ import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS = "table table-bordered table-striped-custom table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { @@ -169,15 +170,19 @@ private[spark] object UIUtils extends Logging { title: String, content: => Seq[Node], activeTab: SparkUITab, - refreshInterval: Option[Int] = None): Seq[Node] = { + refreshInterval: Option[Int] = None, + helpText: Option[String] = None): Seq[Node] = { val appName = activeTab.appName val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } + val helpButton: Seq[Node] = helpText.map { helpText => + (?) + }.getOrElse(Seq.empty) @@ -201,11 +206,17 @@ private[spark] object UIUtils extends Logging {

    {title} + {helpButton}

    {content} + } @@ -232,6 +243,11 @@ private[spark] object UIUtils extends Logging { {content} + } @@ -243,12 +259,10 @@ private[spark] object UIUtils extends Logging { data: Iterable[T], fixedWidth: Boolean = false, id: Option[String] = None, - headerClasses: Seq[String] = Seq.empty): Seq[Node] = { + headerClasses: Seq[String] = Seq.empty, + stripeRowsWithCss: Boolean = true): Seq[Node] = { - var listingTableClass = TABLE_CLASS - if (fixedWidth) { - listingTableClass += " table-fixed" - } + val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" @@ -283,4 +297,24 @@ private[spark] object UIUtils extends Logging { } + + def makeProgressBar( + started: Int, + completed: Int, + failed: Int, + skipped:Int, + total: Int): Seq[Node] = { + val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) + val startWidth = "width: %s%%".format((started.toDouble/total)*100) + +
    + + {completed}/{total} + { if (failed > 0) s"($failed failed)" } + { if (skipped > 0) s"($skipped skipped)" } + +
    +
    +
+ } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index e9c755e36f716..c82730f524eb7 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.net.URLDecoder import javax.servlet.http.HttpServletRequest import scala.util.Try @@ -29,7 +30,19 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage private val sc = parent.sc def render(request: HttpServletRequest): Seq[Node] = { - val executorId = Option(request.getParameter("executorId")).getOrElse { + val executorId = Option(request.getParameter("executorId")).map { + executorId => + // Due to YARN-2844, "<driver>" in the url will be encoded to "%25253Cdriver%25253E" when + // running in yarn-cluster mode. `request.getParameter("executorId")` will return + // "%253Cdriver%253E". Therefore we need to decode it until we get the real id. + var id = executorId + var decodedId = URLDecoder.decode(id, "UTF-8") + while (id != decodedId) { + id = decodedId + decodedId = URLDecoder.decode(id, "UTF-8") + } + id + }.getOrElse { + return Text(s"Missing executorId parameter") } val time = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 048fee3ce1ff4..363cb96de7998 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.net.URLEncoder import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -56,7 +57,7 @@ private[ui] class ExecutorsPage( val execInfoSorted = execInfo.sortBy(_.id) val execTable = - +
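The repeated-decoding workaround for YARN-2844 above can be isolated into a small helper. This is a sketch under the same assumption as the patch, namely that the executor id may have been URL-encoded more than once; the helper name `fullyDecode` is hypothetical.

    import java.net.URLDecoder

    object ExecutorIdDecoding {
      // Decode until a fixed point is reached, e.g. "%253Cdriver%253E" -> "%3Cdriver%3E" -> "<driver>".
      def fullyDecode(raw: String): String = {
        var id = raw
        var decoded = URLDecoder.decode(id, "UTF-8")
        while (id != decoded) {
          id = decoded
          decoded = URLDecoder.decode(id, "UTF-8")
        }
        id
      }

      def main(args: Array[String]): Unit = {
        println(fullyDecode("%253Cdriver%253E"))  // prints <driver>
      }
    }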
    @@ -139,8 +140,9 @@ private[ui] class ExecutorsPage( { if (threadDumpEnabled) { + val encodedId = URLEncoder.encode(info.id, "UTF-8") } else { Seq.empty diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala new file mode 100644 index 0000000000000..ea2d187a0e8e4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.xml.{Node, NodeSeq} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.jobs.UIData.JobUIData + +/** Page showing list of all ongoing and recently finished jobs */ +private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { + private val startTime: Option[Long] = parent.sc.map(_.startTime) + private val listener = parent.listener + + private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = { + val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) + + val columns: Seq[Node] = { + + + + + + + } + + def makeRow(job: JobUIData): Seq[Node] = { + val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageData = lastStageInfo.flatMap { s => + listener.stageIdToData.get((s.stageId, s.attemptId)) + } + val isComplete = job.status == JobExecutionStatus.SUCCEEDED + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") + val duration: Option[Long] = { + job.startTime.map { start => + val end = job.endTime.getOrElse(System.currentTimeMillis()) + end - start + } + } + val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") + val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown") + val detailUrl = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) + + + + + + + + + } + +
    Executor ID Address - Thread Dump + Thread Dump {if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"}DescriptionSubmittedDurationStages: Succeeded/TotalTasks (for all stages): Succeeded/Total
    + {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} + +
    {lastStageDescription}
    + {lastStageName} +
    + {formattedSubmissionTime} + {formattedDuration} + {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages} + {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"} + {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} + + {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + total = job.numTasks - job.numSkippedTasks)} +
    + {columns} + + {jobs.map(makeRow)} + +
    + } + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeJobs = listener.activeJobs.values.toSeq + val completedJobs = listener.completedJobs.reverse.toSeq + val failedJobs = listener.failedJobs.reverse.toSeq + val now = System.currentTimeMillis + + val activeJobsTable = + jobsTable(activeJobs.sortBy(_.startTime.getOrElse(-1L)).reverse) + val completedJobsTable = + jobsTable(completedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + val failedJobsTable = + jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + + val summary: NodeSeq = +
    +
      + {if (startTime.isDefined) { + // Total duration is not meaningful unless the UI is live +
    • + Total Duration: + {UIUtils.formatDuration(now - startTime.get)} +
    • + }} +
    • + Scheduling Mode: + {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} +
    • +
    • + Active Jobs: + {activeJobs.size} +
    • +
    • + Completed Jobs: + {completedJobs.size} +
    • +
    • + Failed Jobs: + {failedJobs.size} +
    • +
    +
    + + val content = summary ++ +

    Active Jobs ({activeJobs.size})

    ++ activeJobsTable ++ +

    Completed Jobs ({completedJobs.size})

    ++ completedJobsTable ++ +

    Failed Jobs ({failedJobs.size})

++ failedJobsTable + + val helpText = """A job is triggered by an action, like "count()" or "saveAsTextFile()".""" + + " Click on a job's title to see information about the stages of tasks associated with" + + " the job." + + UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala similarity index 87% rename from core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 83a7898071c9b..b0f8ca2ab0d3f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.Schedulable import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ -private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { +private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc private val listener = parent.listener private def isFairScheduler = parent.isFairScheduler @@ -41,11 +41,14 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, - parent, parent.killEnabled) + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) val completedStagesTable = - new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent) + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) val failedStagesTable = - new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent) + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) @@ -93,7 +96,7 @@

    Failed Stages ({numFailedStages})

    ++ failedStagesTable.toNodeSeq - UIUtils.headerSparkPage("Spark Stages", content, parent) + UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index fa0f96bff34ff..9836d11a6d85f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) { +private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -36,7 +36,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr /** Special table which merges two header cells. */ private def executorTable[T](): Seq[Node] = { - +
    diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala new file mode 100644 index 0000000000000..77d36209c6048 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.collection.mutable +import scala.xml.{NodeSeq, Node} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.ui.{UIUtils, WebUIPage} + +/** Page showing statistics and stage list for a given job */ +private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { + private val listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val jobId = request.getParameter("id").toInt + val jobDataOption = listener.jobIdToData.get(jobId) + if (jobDataOption.isEmpty) { + val content = +
    +

    No information to display for job {jobId}

    +
    + return UIUtils.headerSparkPage( + s"Details for Job $jobId", content, parent) + } + val jobData = jobDataOption.get + val isComplete = jobData.status != JobExecutionStatus.RUNNING + val stages = jobData.stageIds.map { stageId => + // This could be empty if the JobProgressListener hasn't received information about the + // stage or if the stage information has been garbage collected + listener.stageIdToInfo.getOrElse(stageId, + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, "Unknown")) + } + + val activeStages = mutable.Buffer[StageInfo]() + val completedStages = mutable.Buffer[StageInfo]() + // If the job is completed, then any pending stages are displayed as "skipped": + val pendingOrSkippedStages = mutable.Buffer[StageInfo]() + val failedStages = mutable.Buffer[StageInfo]() + for (stage <- stages) { + if (stage.submissionTime.isEmpty) { + pendingOrSkippedStages += stage + } else if (stage.completionTime.isDefined) { + if (stage.failureReason.isDefined) { + failedStages += stage + } else { + completedStages += stage + } + } else { + activeStages += stage + } + } + + val activeStagesTable = + new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) + val pendingOrSkippedStagesTable = + new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = false) + val completedStagesTable = + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) + val failedStagesTable = + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) + + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + + val summary: NodeSeq = +
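The stage classification performed above (pending or skipped, failed, completed, active) follows a simple rule on the submission and completion timestamps. The sketch below restates that rule over a hypothetical minimal record type so the bucketing can be read in isolation; `StageRec` and `bucket` are illustrative names, not part of the patch.

    object StageBucketSketch {
      import scala.collection.mutable

      // Stand-in for StageInfo with just the fields the classification needs.
      case class StageRec(
          stageId: Int,
          submissionTime: Option[Long],
          completionTime: Option[Long],
          failureReason: Option[String])

      // Returns (active, pendingOrSkipped, completed, failed).
      def bucket(stages: Seq[StageRec]): (Seq[StageRec], Seq[StageRec], Seq[StageRec], Seq[StageRec]) = {
        val active = mutable.Buffer[StageRec]()
        val pendingOrSkipped = mutable.Buffer[StageRec]()
        val completed = mutable.Buffer[StageRec]()
        val failed = mutable.Buffer[StageRec]()
        for (s <- stages) {
          if (s.submissionTime.isEmpty) {
            pendingOrSkipped += s   // never submitted: pending, or "skipped" once the job is complete
          } else if (s.completionTime.isDefined) {
            if (s.failureReason.isDefined) failed += s else completed += s
          } else {
            active += s             // submitted but not yet completed
          }
        }
        (active.toSeq, pendingOrSkipped.toSeq, completed.toSeq, failed.toSeq)
      }
    }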
    +
      +
    • + Status: + {jobData.status} +
    • + { + if (jobData.jobGroup.isDefined) { +
    • + Job Group: + {jobData.jobGroup.get} +
    • + } + } + { + if (shouldShowActiveStages) { +
    • + Active Stages: + {activeStages.size} +
    • + } + } + { + if (shouldShowPendingStages) { +
    • + + Pending Stages: + {pendingOrSkippedStages.size} +
    • + } + } + { + if (shouldShowCompletedStages) { +
    • + Completed Stages: + {completedStages.size} +
    • + } + } + { + if (shouldShowSkippedStages) { +
    • + Skipped Stages: + {pendingOrSkippedStages.size} +
    • + } + } + { + if (shouldShowFailedStages) { +
    • + Failed Stages: + {failedStages.size} +
    • + } + } +
    +
    + + var content = summary + if (shouldShowActiveStages) { + content ++=

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

    Completed Stages ({completedStages.size})

    ++ + completedStagesTable.toNodeSeq + } + if (shouldShowSkippedStages) { + content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

    Failed Stages ({failedStages.size})

    ++ + failedStagesTable.toNodeSeq + } + UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 8bbde51e1801c..72935beb3a34a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -40,48 +40,145 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { import JobProgressListener._ + // Define a handful of type aliases so that data structures' types can serve as documentation. + // These type aliases are public because they're used in the types of public fields: + type JobId = Int type StageId = Int type StageAttemptId = Int + type PoolName = String + type ExecutorId = String - // How many stages to remember - val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) - // How many jobs to remember - val retailedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS) - + // Jobs: val activeJobs = new HashMap[JobId, JobUIData] val completedJobs = ListBuffer[JobUIData]() val failedJobs = ListBuffer[JobUIData]() val jobIdToData = new HashMap[JobId, JobUIData] + // Stages: val activeStages = new HashMap[StageId, StageInfo] val completedStages = ListBuffer[StageInfo]() + val skippedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] - - // Number of completed and failed stages, may not actually equal to completedStages.size and - // failedStages.size respectively due to completedStage and failedStages only maintain the latest - // part of the stages, the earlier ones will be removed when there are too many stages for - // memory sake. + val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]] + val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]() + // Total of completed and failed stages that have ever been run. These may be greater than + // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than + // JobProgressListener's retention limits. var numCompletedStages = 0 var numFailedStages = 0 - // Map from pool name to a hash map (map from stage id to StageInfo). 
- val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() - - val executorIdToBlockManagerId = HashMap[String, BlockManagerId]() + // Misc: + val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() + def blockManagerIds = executorIdToBlockManagerId.values.toSeq var schedulingMode: Option[SchedulingMode] = None - def blockManagerIds = executorIdToBlockManagerId.values.toSeq + // To limit the total memory usage of JobProgressListener, we only track information for a fixed + // number of non-active jobs and stages (there is no limit for active jobs and stages): + + val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) + val retainedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS) + + // We can test for memory leaks by ensuring that collections that track non-active jobs and + // stages do not grow without bound and that collections for active jobs/stages eventually become + // empty once Spark is idle. Let's partition our collections into ones that should be empty + // once Spark is idle and ones that should have a hard- or soft-limited sizes. + // These methods are used by unit tests, but they're defined here so that people don't forget to + // update the tests when adding new collections. Some collections have multiple levels of + // nesting, etc, so this lets us customize our notion of "size" for each structure: + + // These collections should all be empty once Spark is idle (no active stages / jobs): + private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = { + Map( + "activeStages" -> activeStages.size, + "activeJobs" -> activeJobs.size, + "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum, + "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum + ) + } + + // These collections should stop growing once we have run at least `spark.ui.retainedStages` + // stages and `spark.ui.retainedJobs` jobs: + private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = { + Map( + "completedJobs" -> completedJobs.size, + "failedJobs" -> failedJobs.size, + "completedStages" -> completedStages.size, + "skippedStages" -> skippedStages.size, + "failedStages" -> failedStages.size + ) + } + + // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to + // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: + private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { + Map( + "jobIdToData" -> jobIdToData.size, + "stageIdToData" -> stageIdToData.size, + "stageIdToStageInfo" -> stageIdToInfo.size + ) + } + + /** If stages is too large, remove and garbage collect old stages */ + private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { + if (stages.size > retainedStages) { + val toRemove = math.max(retainedStages / 10, 1) + stages.take(toRemove).foreach { s => + stageIdToData.remove((s.stageId, s.attemptId)) + stageIdToInfo.remove(s.stageId) + } + stages.trimStart(toRemove) + } + } + + /** If jobs is too large, remove and garbage collect old jobs */ + private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized { + if (jobs.size > retainedJobs) { + val toRemove = math.max(retainedJobs / 10, 1) + jobs.take(toRemove).foreach { job => + jobIdToData.remove(job.jobId) + } + jobs.trimStart(toRemove) + } + } override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { - val jobGroup = 
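The two trim methods above share one retention policy: once a buffer exceeds its configured limit, drop the oldest ten percent (at least one element) and let the caller clean up any per-element indices. A generic sketch of that policy, with an assumed callback name `onRemove`, looks like this:

    import scala.collection.mutable.ListBuffer

    object RetentionSketch {
      // Drop the oldest elements once `buf` exceeds `retained`, invoking `onRemove` for each
      // dropped element so secondary maps keyed by those elements can be cleaned up as well.
      def trimIfNecessary[T](buf: ListBuffer[T], retained: Int)(onRemove: T => Unit): Unit = {
        if (buf.size > retained) {
          val toRemove = math.max(retained / 10, 1)
          buf.take(toRemove).foreach(onRemove)
          buf.trimStart(toRemove)
        }
      }
    }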
Option(jobStart.properties).map(_.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + val jobGroup = for ( + props <- Option(jobStart.properties); + group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) yield group val jobData: JobUIData = - new JobUIData(jobStart.jobId, jobStart.stageIds, jobGroup, JobExecutionStatus.RUNNING) + new JobUIData( + jobId = jobStart.jobId, + startTime = Some(System.currentTimeMillis), + endTime = None, + stageIds = jobStart.stageIds, + jobGroup = jobGroup, + status = JobExecutionStatus.RUNNING) + // Compute (a potential underestimate of) the number of tasks that will be run by this job. + // This may be an underestimate because the job start event references all of the result + // stages's transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + jobData.numTasks = { + val allStages = jobStart.stageInfos + val missingStages = allStages.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } jobIdToData(jobStart.jobId) = jobData activeJobs(jobStart.jobId) = jobData + for (stageId <- jobStart.stageIds) { + stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId) + } + // If there's no information for a stage, store the StageInfo received from the scheduler + // so that we can display stage descriptions for pending stages: + for (stageInfo <- jobStart.stageInfos) { + stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo) + stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData) + } } override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { @@ -89,14 +186,31 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) } + jobData.endTime = Some(System.currentTimeMillis()) jobEnd.jobResult match { case JobSucceeded => completedJobs += jobData + trimJobsIfNecessary(completedJobs) jobData.status = JobExecutionStatus.SUCCEEDED case JobFailed(exception) => failedJobs += jobData + trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED } + for (stageId <- jobData.stageIds) { + stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => + jobsUsingStage.remove(jobEnd.jobId) + stageIdToInfo.get(stageId).foreach { stageInfo => + if (stageInfo.submissionTime.isEmpty) { + // if this stage is pending, it won't complete, so mark it as "skipped": + skippedStages += stageInfo + trimStagesIfNecessary(skippedStages) + jobData.numSkippedStages += 1 + jobData.numSkippedTasks += stageInfo.numTasks + } + } + } + } } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { @@ -118,23 +232,24 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (stage.failureReason.isEmpty) { completedStages += stage numCompletedStages += 1 - trimIfNecessary(completedStages) + trimStagesIfNecessary(completedStages) } else { failedStages += stage numFailedStages += 1 - trimIfNecessary(failedStages) + trimStagesIfNecessary(failedStages) } - } - /** If stages is too large, remove and garbage collect old stages */ - private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { - if (stages.size > retainedStages) { - val toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => - 
stageIdToData.remove((s.stageId, s.attemptId)) - stageIdToInfo.remove(s.stageId) + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages -= 1 + if (stage.failureReason.isEmpty) { + jobData.completedStageIndices.add(stage.stageId) + } else { + jobData.numFailedStages += 1 } - stages.trimStart(toRemove) } } @@ -157,6 +272,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) stages(stage.stageId) = stage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages += 1 + } } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { @@ -169,6 +292,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.numActiveTasks += 1 stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) } + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks += 1 + } } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { @@ -226,6 +356,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskData.taskInfo = info taskData.taskMetrics = metrics taskData.errorMessage = errorMessage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks -= 1 + taskEnd.reason match { + case Success => + jobData.numCompletedTasks += 1 + case _ => + jobData.numFailedTasks += 1 + } + } } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala similarity index 58% rename from graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 49b2704390fea..b2bbfdee56946 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -15,23 +15,18 @@ * limitations under the License. */ -package org.apache.spark.graphx.impl +package org.apache.spark.ui.jobs -import scala.reflect.ClassTag -import scala.util.Random +import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.ui.{SparkUI, SparkUITab} -import org.scalatest.FunSuite +/** Web UI showing progress status of all jobs in the given SparkContext. 
*/ +private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { + val sc = parent.sc + val killEnabled = parent.killEnabled + def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + val listener = parent.jobProgressListener -import org.apache.spark.graphx._ - -class EdgeTripletIteratorSuite extends FunSuite { - test("iterator.toList") { - val builder = new EdgePartitionBuilder[Int, Int] - builder.add(1, 2, 0) - builder.add(1, 3, 0) - builder.add(1, 4, 0) - val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true) - val result = iter.toList.map(et => (et.srcId, et.dstId)) - assert(result === Seq((1, 2), (1, 3), (1, 4))) - } + attachPage(new AllJobsPage(this)) + attachPage(new JobPage(this)) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 770d99eea1c9d..5fc6cc7533150 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ -private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { +private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { private val sc = parent.sc private val listener = parent.listener @@ -37,8 +37,9 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { case Some(s) => s.values.toSeq case None => Seq[StageInfo]() } - val activeStagesTable = - new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent) + val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) // For now, pool information is only accessible in live UIs val pools = sc.map(_.getPoolForName(poolName).get).toSeq diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 64178e1e33d41..df1899e7a9b84 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -24,7 +24,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.UIUtils /** Table showing list of pools */ -private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { +private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 16bc3f6c18d09..bfa54f8492068 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.{Utils, Distribution} import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} /** Page showing statistics and task list for a given stage */ -private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { +private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -114,6 +114,10 @@ private[ui] class StagePage(parent: 
JobProgressTab) extends WebUIPage("stage") { ++ @@ -73,25 +72,11 @@ private[ui] class StageTableBase(
    Executor ID Address Stage Id
    } - private def makeProgressBar(started: Int, completed: Int, failed: Int, total: Int): Seq[Node] = - { - val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) - -
    - - {completed}/{total} { if (failed > 0) s"($failed failed)" else "" } - -
    -
    -
    - } - private def makeDescription(s: StageInfo): Seq[Node] = { // scalastyle:off val killLink = if (killEnabled) { val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + .format(UIUtils.prependBaseUri(basePath), s.stageId) val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" .format(s.stageId) @@ -101,7 +86,7 @@ private[ui] class StageTableBase( // scalastyle:on val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId, s.attemptId) + .format(UIUtils.prependBaseUri(basePath), s.stageId, s.attemptId) val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -115,7 +100,7 @@ private[ui] class StageTableBase( Text("RDD: ") ++ // scalastyle:off cachedRddInfos.map { i => - {i.name} + {i.name} } // scalastyle:on }} @@ -167,7 +152,7 @@ private[ui] class StageTableBase( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(basePath), stageData.schedulingPool)}> {stageData.schedulingPool} @@ -175,11 +160,14 @@ private[ui] class StageTableBase( Seq.empty }} ++ {makeDescription(s)} - {submissionTime} + + {submissionTime} + {formattedDuration} - {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, - stageData.numFailedTasks, s.numTasks)} + {UIUtils.makeProgressBar(started = stageData.numActiveTasks, + completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, + skipped = 0, total = s.numTasks)} {inputReadWithUnit} {outputWriteWithUnit} @@ -193,9 +181,10 @@ private[ui] class StageTableBase( private[ui] class FailedStageTable( stages: Seq[StageInfo], - parent: JobProgressTab, - killEnabled: Boolean = false) - extends StageTableBase(stages, parent, killEnabled) { + basePath: String, + listener: JobProgressListener, + isFairScheduler: Boolean) + extends StageTableBase(stages, basePath, listener, isFairScheduler, killEnabled = false) { override protected def columns: Seq[Node] = super.columns ++ Failure Reason diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 03ca918e2e8b3..937261de00e3a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -19,18 +19,16 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import org.apache.spark.SparkConf import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, SparkUITab} -/** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { +/** Web UI showing progress status of all stages in the given SparkContext. 
*/ +private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { val sc = parent.sc - val conf = sc.map(_.conf).getOrElse(new SparkConf) - val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + val killEnabled = parent.killEnabled val listener = parent.jobProgressListener - attachPage(new JobProgressPage(this)) + attachPage(new AllStagesPage(this)) attachPage(new StagePage(this)) attachPage(new PoolPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index eb371bd0ea7ed..ca942c4051c84 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -20,6 +20,9 @@ package org.apache.spark.ui.jobs /** * Names of the CSS classes corresponding to each type of task detail. Used to allow users * to optionally show/hide columns. + * + * If new optional metrics are added here, they should also be added to the end of webui.css + * to have the style set to "display: none;" by default. */ private object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 2f7d618df5f6f..48fd7caa1a1ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -40,9 +40,28 @@ private[jobs] object UIData { class JobUIData( var jobId: Int = -1, + var startTime: Option[Long] = None, + var endTime: Option[Long] = None, var stageIds: Seq[Int] = Seq.empty, var jobGroup: Option[String] = None, - var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN + var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, + /* Tasks */ + // `numTasks` is a potential underestimate of the true number of tasks that this job will run. + // This may be an underestimate because the job start event references all of the result + // stages's transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + var numTasks: Int = 0, + var numActiveTasks: Int = 0, + var numCompletedTasks: Int = 0, + var numSkippedTasks: Int = 0, + var numFailedTasks: Int = 0, + /* Stages */ + var numActiveStages: Int = 0, + // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: + var completedStageIndices: OpenHashSet[Int] = new OpenHashSet[Int](), + var numSkippedStages: Int = 0, + var numFailedStages: Int = 0 ) class StageUIData { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 10010bdfa1a51..8c2457f56bffe 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -134,9 +134,16 @@ private[spark] object AkkaUtils extends Logging { Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") } + private val AKKA_MAX_FRAME_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 + /** Returns the configured max frame size for Akka messages in bytes. 
*/ def maxFrameSizeBytes(conf: SparkConf): Int = { - conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024 + val frameSizeInMB = conf.getInt("spark.akka.frameSize", 10) + if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { + throw new IllegalArgumentException("spark.akka.frameSize should not be greater than " + + AKKA_MAX_FRAME_SIZE_IN_MB + "MB") + } + frameSizeInMB * 1024 * 1024 } /** Space reserved for extra data in an Akka message besides serialized task or task result. */ diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 7e536edfe807b..e7b80e8774b9c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -31,6 +31,21 @@ import org.apache.spark.scheduler._ import org.apache.spark.storage._ import org.apache.spark._ +/** + * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- + * and forwards-compatibility guarantees: any version of Spark should be able to read JSON output + * written by any other version, including newer versions. + * + * JsonProtocolSuite contains backwards-compatibility tests which check that the current version of + * JsonProtocol is able to read output written by earlier versions. We do not currently have tests + * for reading newer JSON output with older Spark versions. + * + * To ensure that we provide these guarantees, follow these rules when modifying these methods: + * + * - Never delete any JSON fields. + * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * in `*FromJson` methods. + */ private[spark] object JsonProtocol { // TODO: Remove this file and put JSON serialization into each individual class. 
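The compatibility rules above mean that any reader of a new field must tolerate its absence in JSON produced by older Spark versions. The sketch below illustrates that pattern with json4s, the library JsonProtocol is built on; the field name "Stage Names" is purely hypothetical and stands in for any field added in a later version, and `jsonOption` is a local re-statement of the helper mentioned above.

    import org.json4s._
    import org.json4s.jackson.JsonMethods.parse

    object JsonCompatSketch {
      implicit val formats: Formats = DefaultFormats

      // A missing field parses to JNothing; treat that as None.
      def jsonOption(json: JValue): Option[JValue] = json match {
        case JNothing => None
        case value => Some(value)
      }

      def main(args: Array[String]): Unit = {
        // JSON written by an older version: it has "Stage IDs" but not the (hypothetical) "Stage Names".
        val oldEvent = parse("""{"Event": "JobStart", "Job ID": 1, "Stage IDs": [0, 1]}""")
        val names = jsonOption(oldEvent \ "Stage Names")
          .map(_.extract[List[String]])
          .getOrElse((oldEvent \ "Stage IDs").extract[List[Int]].map(id => s"stage-$id"))
        println(names)  // List(stage-0, stage-1)
      }
    }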
@@ -121,6 +136,7 @@ private[spark] object JsonProtocol { val properties = propertiesToJson(jobStart.properties) ("Event" -> Utils.getFormattedClassName(jobStart)) ~ ("Job ID" -> jobStart.jobId) ~ + ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 ("Stage IDs" -> jobStart.stageIds) ~ ("Properties" -> properties) } @@ -455,7 +471,12 @@ private[spark] object JsonProtocol { val jobId = (json \ "Job ID").extract[Int] val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") - SparkListenerJobStart(jobId, stageIds, properties) + // The "Stage Infos" field was added in Spark 1.2.0 + val stageInfos = Utils.jsonOption(json \ "Stage Infos") + .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { + stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) + } + SparkListenerJobStart(jobId, stageInfos, properties) } def jobEndFromJson(json: JValue): SparkListenerJobEnd = { @@ -667,6 +688,10 @@ private[spark] object JsonProtocol { } def blockManagerIdFromJson(json: JValue): BlockManagerId = { + // On metadata fetch fail, block manager ID can be null (SPARK-4471) + if (json == JNothing) { + return null + } val executorId = (json \ "Executor ID").extract[String] val host = (json \ "Host").extract[String] val port = (json \ "Port").extract[Int] diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2889e171f627e..ac40f19ed6799 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -52,7 +52,7 @@ private[spark] class MetadataCleaner( logDebug( "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + "and period of " + periodSeconds + " secs") - timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + timer.schedule(task, delaySeconds * 1000, periodSeconds * 1000) } def cancel() { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index eb4a598dbf857..336b0798cade9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -262,7 +262,7 @@ private[spark] object Utils extends Logging { if (dir.exists() || !dir.mkdirs()) { dir = null } - } catch { case e: IOException => ; } + } catch { case e: SecurityException => dir = null; } } registerShutdownDeleteDir(dir) diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index d44e15e3c97ea..4d43d8d5cc8d8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.collection +import scala.reflect.ClassTag + /** * An append-only buffer similar to ArrayBuffer, but more memory-efficient for small buffers. * ArrayBuffer always allocates an Object array to store the data, with 16 entries by default, @@ -25,7 +27,7 @@ package org.apache.spark.util.collection * entries than that. This makes it more efficient for operations like groupBy where we expect * some keys to have very few elements. 
*/ -private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { +private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable { // First two elements private var element0: T = _ private var element1: T = _ @@ -34,7 +36,7 @@ private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { private var curSize = 0 // Array for extra elements - private var otherElements: Array[AnyRef] = null + private var otherElements: Array[T] = null def apply(position: Int): T = { if (position < 0 || position >= curSize) { @@ -45,7 +47,7 @@ private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { } else if (position == 1) { element1 } else { - otherElements(position - 2).asInstanceOf[T] + otherElements(position - 2) } } @@ -58,7 +60,7 @@ private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { } else if (position == 1) { element1 = value } else { - otherElements(position - 2) = value.asInstanceOf[AnyRef] + otherElements(position - 2) = value } } @@ -72,7 +74,7 @@ private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { curSize = 2 } else { growToSize(curSize + 1) - otherElements(newIndex - 2) = value.asInstanceOf[AnyRef] + otherElements(newIndex - 2) = value } this } @@ -139,7 +141,7 @@ private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { newArrayLen = Int.MaxValue - 2 } } - val newArray = new Array[AnyRef](newArrayLen) + val newArray = new Array[T](newArrayLen) if (otherElements != null) { System.arraycopy(otherElements, 0, newArray, 0, otherElements.length) } @@ -150,9 +152,9 @@ private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { } private[spark] object CompactBuffer { - def apply[T](): CompactBuffer[T] = new CompactBuffer[T] + def apply[T: ClassTag](): CompactBuffer[T] = new CompactBuffer[T] - def apply[T](value: T): CompactBuffer[T] = { + def apply[T: ClassTag](value: T): CompactBuffer[T] = { val buf = new CompactBuffer[T] buf += value } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 26fa0cb6d7bde..8a0f5a602de12 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -76,10 +76,6 @@ class ExternalAppendOnlyMap[K, V, C]( private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - // Number of pairs inserted since last spill; note that we count them even if a value is merged - // with a previous key in case we're doing something like groupBy where the result grows - protected[this] var elementsRead = 0L - /** * Size of object batches when reading/writing from serializers. 
* @@ -132,7 +128,7 @@ class ExternalAppendOnlyMap[K, V, C]( currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) - elementsRead += 1 + addElementsRead() } } @@ -209,8 +205,6 @@ class ExternalAppendOnlyMap[K, V, C]( } spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - - elementsRead = 0 } def diskBytesSpilled: Long = _diskBytesSpilled diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index c1ce13683b569..15bda1c9cc29c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -119,10 +119,6 @@ private[spark] class ExternalSorter[K, V, C]( private var map = new SizeTrackingAppendOnlyMap[(Int, K), C] private var buffer = new SizeTrackingPairBuffer[(Int, K), C] - // Number of pairs read from input since last spill; note that we count them even if a value is - // merged with a previous key in case we're doing something like groupBy where the result grows - protected[this] var elementsRead = 0L - // Total spilling statistics private var _diskBytesSpilled = 0L @@ -204,15 +200,22 @@ private[spark] class ExternalSorter[K, V, C]( if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) } while (records.hasNext) { - elementsRead += 1 + addElementsRead() kv = records.next() map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) } + } else if (bypassMergeSort) { + // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies + if (records.hasNext) { + spillToPartitionFiles(records.map { kv => + ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + }) + } } else { // Stick values into our buffer while (records.hasNext) { - elementsRead += 1 + addElementsRead() val kv = records.next() buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) @@ -340,6 +343,10 @@ private[spark] class ExternalSorter[K, V, C]( * @param collection whichever collection we're using (map or buffer) */ private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + spillToPartitionFiles(collection.iterator) + } + + private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = { assert(bypassMergeSort) // Create our file writers if we haven't done so yet @@ -354,9 +361,9 @@ private[spark] class ExternalSorter[K, V, C]( } } - val it = collection.iterator // No need to sort stuff, just write each element out - while (it.hasNext) { - val elem = it.next() + // No need to sort stuff, just write each element out + while (iterator.hasNext) { + val elem = iterator.next() val partitionId = elem._1._1 val key = elem._1._2 val value = elem._2 @@ -752,6 +759,12 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled context.taskMetrics.diskBytesSpilled += diskBytesSpilled + context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m => + if (curWriteMetrics != null) { + m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten + m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime + } + } lengths } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index d7dccd4af8c6e..9f54312074856 100644 --- 
a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -24,10 +24,7 @@ import org.apache.spark.SparkEnv * Spills contents of an in-memory collection to disk when the memory threshold * has been exceeded. */ -private[spark] trait Spillable[C] { - - this: Logging => - +private[spark] trait Spillable[C] extends Logging { /** * Spills the current in-memory collection to disk, and releases the memory. * @@ -36,16 +33,29 @@ private[spark] trait Spillable[C] { protected def spill(collection: C): Unit // Number of elements read from input since last spill - protected var elementsRead: Long + protected def elementsRead: Long = _elementsRead + + // Called by subclasses every time a record is read + // It's used for checking spilling frequency + protected def addElementsRead(): Unit = { _elementsRead += 1 } // Memory manager that can be used to acquire/release memory private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager - // What threshold of elementsRead we start estimating collection size at + // Threshold for `elementsRead` before we start tracking this collection's memory usage private[this] val trackMemoryThreshold = 1000 - // How much of the shared memory pool this collection has claimed - private[this] var myMemoryThreshold = 0L + // Initial threshold for the size of a collection before we start tracking its memory usage + // Exposed for testing + private[this] val initialMemoryThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) + + // Threshold for this collection's size in bytes before we start tracking its memory usage + // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 + private[this] var myMemoryThreshold = initialMemoryThreshold + + // Number of elements read from input since last spill + private[this] var _elementsRead = 0L // Number of bytes spilled in total private[this] var _memoryBytesSpilled = 0L @@ -76,6 +86,7 @@ private[spark] trait Spillable[C] { spill(collection) + _elementsRead = 0 // Keep track of spills, and release memory _memoryBytesSpilled += currentMemory releaseMemoryForThisThread() @@ -94,8 +105,9 @@ private[spark] trait Spillable[C] { * Release our memory back to the shuffle pool so that other threads can grab it. 
*/ private def releaseMemoryForThisThread(): Unit = { - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0L + // The amount we requested does not include the initial memory tracking threshold + shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold) + myMemoryThreshold = initialMemoryThreshold } /** @@ -105,7 +117,8 @@ private[spark] trait Spillable[C] { */ @inline private def logSpillage(size: Long) { val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" - .format(threadId, size / (1024 * 1024), _spillCount, if (_spillCount > 1) "s" else "")) + logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)" + .format(threadId, org.apache.spark.util.Utils.bytesToString(size), + _spillCount, if (_spillCount > 1) "s" else "")) } } diff --git a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java new file mode 100644 index 0000000000000..7fe452a48d89b --- /dev/null +++ b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark; + +import java.io.Serializable; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.rdd.JdbcRDD; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class JavaJdbcRDDSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() throws ClassNotFoundException, SQLException { + sc = new JavaSparkContext("local", "JavaAPISuite"); + + Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); + Connection connection = + DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true"); + + try { + Statement create = connection.createStatement(); + create.execute( + "CREATE TABLE FOO(" + + "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," + + "DATA INTEGER)"); + create.close(); + + PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)"); + for (int i = 1; i <= 100; i++) { + insert.setInt(1, i * 2); + insert.executeUpdate(); + } + insert.close(); + } catch (SQLException e) { + // If table doesn't exist... 
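The Spillable changes above move the element counter into the trait and seed the memory threshold with a non-zero initial value so that many tiny spills are avoided. A stripped-down sketch of that bookkeeping, omitting the shuffle memory manager interaction and using assumed names, is:

    object SpillBookkeepingSketch {
      private val trackMemoryThreshold = 1000                 // elements read before size tracking kicks in
      private val initialMemoryThreshold = 5L * 1024 * 1024   // assumed 5 MB starting threshold

      private var myMemoryThreshold = initialMemoryThreshold
      private var elementsRead = 0L
      private var spillCount = 0

      def addElementsRead(): Unit = { elementsRead += 1 }

      // Decide whether an in-memory collection of estimated size `currentMemory` bytes should spill.
      def maybeSpill(currentMemory: Long)(spill: () => Unit): Boolean = {
        if (elementsRead > trackMemoryThreshold && currentMemory >= myMemoryThreshold) {
          spill()
          spillCount += 1
          elementsRead = 0
          // Start over from the initial threshold after releasing what was claimed above it.
          myMemoryThreshold = initialMemoryThreshold
          true
        } else {
          false
        }
      }
    }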
+ if (e.getSQLState().compareTo("X0Y32") != 0) { + throw e; + } + } finally { + connection.close(); + } + } + + @After + public void tearDown() throws SQLException { + try { + DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;shutdown=true"); + } catch(SQLException e) { + // Throw if not normal single database shutdown + // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html + if (e.getSQLState().compareTo("08006") != 0) { + throw e; + } + } + + sc.stop(); + sc = null; + } + + @Test + public void testJavaJdbcRDD() throws Exception { + JavaRDD<Integer> rdd = JdbcRDD.create( + sc, + new JdbcRDD.ConnectionFactory() { + @Override + public Connection getConnection() throws SQLException { + return DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb"); + } + }, + "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", + 1, 100, 1, + new Function<ResultSet, Integer>() { + @Override + public Integer call(ResultSet r) throws Exception { + return r.getInt(1); + } + } + ).cache(); + + Assert.assertEquals(100, rdd.count()); + Assert.assertEquals( + Integer.valueOf(10100), + rdd.reduce(new Function2<Integer, Integer, Integer>() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + })); + } +} diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 66cf60d25f6d1..ce804f94f3267 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -37,20 +37,24 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") intercept[SparkException] { new SparkContext(conf) } SparkEnv.get.stop() // cleanup the created environment + SparkContext.clearActiveContext() // Only min val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1") intercept[SparkException] { new SparkContext(conf1) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Only max val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2") intercept[SparkException] { new SparkContext(conf2) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) @@ -76,6 +80,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("add executors") { sc = createSparkContext(1, 10) val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached assert(numExecutorsPending(manager) === 0) @@ -117,6 +122,51 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(numExecutorsToAdd(manager) === 1) } + test("add executors capped by num pending tasks") { + sc = createSparkContext(1, 10) + val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) + + // Verify that we're capped at number of tasks in the stage + assert(numExecutorsPending(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager)
=== 3) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 5) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task reduces the cap + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) + sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 6) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 7) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that re-running a task doesn't reduce the cap further + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 8) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 9) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task once we're at our limit doesn't blow things up + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) + assert(addExecutors(manager) === 0) + assert(numExecutorsPending(manager) === 9) + } + test("remove executors") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get @@ -170,6 +220,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test ("interleaving add and remove") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Add a few executors assert(addExecutors(manager) === 1) @@ -343,6 +394,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val clock = new TestClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 6608ed1e57b38..55799f55146cb 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -38,8 +38,8 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf) - rpcHandler = new ExternalShuffleBlockHandler() + val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) + rpcHandler = new ExternalShuffleBlockHandler(transportConf) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index cda942e15a704..5d20b4dc1561a 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -95,14 +95,14 @@ abstract class ShuffleSuite extends FunSuite 
with Matchers with LocalSparkContex // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 + // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys + val NUM_BLOCKS = 201 val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer doesn't create zero-sized blocks. // So, use Kryo - val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS)) .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -122,13 +122,13 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 + // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys + val NUM_BLOCKS = 201 val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) @@ -270,7 +270,6 @@ object ShuffleSuite { def mergeCombineException(x: Int, y: Int): Int = { throw new SparkException("Exception for map-side combine.") - x + y } class NonJavaSerializableClass(val value: Int) extends Comparable[NonJavaSerializableClass] { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 31edad1c56c73..1362022104195 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -21,15 +21,68 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable -class SparkContextSuite extends FunSuite { - //Regression test for SPARK-3121 +class SparkContextSuite extends FunSuite with LocalSparkContext { + + /** Allows system properties to be changed in tests */ + private def withSystemProperty[T](property: String, value: String)(block: => T): T = { + val originalValue = System.getProperty(property) + try { + System.setProperty(property, value) + block + } finally { + if (originalValue == null) { + System.clearProperty(property) + } else { + System.setProperty(property, originalValue) + } + } + } + + test("Only one SparkContext may be active at a time") { + // Regression test for SPARK-4180 + withSystemProperty("spark.driver.allowMultipleContexts", "false") { + val conf = new SparkConf().setAppName("test").setMaster("local") + sc = new SparkContext(conf) + // A SparkContext is already running, so we shouldn't be able to create a second one + intercept[SparkException] { new SparkContext(conf) } + // After stopping the running context, we should be able to create a new one + resetSparkContext() + sc = new SparkContext(conf) + } + } + + test("Can still construct a new SparkContext after failing to construct a previous one") { + withSystemProperty("spark.driver.allowMultipleContexts", "false") { + // This is an invalid 
configuration (no app name or master URL) + intercept[SparkException] { + new SparkContext(new SparkConf()) + } + // Even though those earlier calls failed, we should still be able to create a new context + sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test")) + } + } + + test("Check for multiple SparkContexts can be disabled via undocumented debug option") { + withSystemProperty("spark.driver.allowMultipleContexts", "true") { + var secondSparkContext: SparkContext = null + try { + val conf = new SparkConf().setAppName("test").setMaster("local") + sc = new SparkContext(conf) + secondSparkContext = new SparkContext(conf) + } finally { + Option(secondSparkContext).foreach(_.stop()) + } + } + } + test("BytesWritable implicit conversion is correct") { + // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() val inputArray = (1 to 10).map(_.toByte).toArray bytesWritable.set(inputArray, 0, 10) bytesWritable.set(inputArray, 0, 5) - val converter = SparkContext.bytesWritableConverter() + val converter = WritableConverter.bytesWritableConverter() val byteArray = converter.convert(bytesWritable) assert(byteArray.length === 5) diff --git a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala similarity index 69% rename from core/src/test/scala/org/apache/spark/StatusAPISuite.scala rename to core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 4468fba8c1dff..8577e4ac7e33e 100644 --- a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -27,9 +27,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.JobExecutionStatus._ import org.apache.spark.SparkContext._ -class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { +class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { test("basic status API usage") { + sc = new SparkContext("local", "test", new SparkConf(false)) val jobFuture = sc.parallelize(1 to 10000, 2).map(identity).groupBy(identity).collectAsync() val jobId: Int = eventually(timeout(10 seconds)) { val jobIds = jobFuture.jobIds @@ -37,20 +38,20 @@ class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { jobIds.head } val jobInfo = eventually(timeout(10 seconds)) { - sc.getJobInfo(jobId).get + sc.statusTracker.getJobInfo(jobId).get } jobInfo.status() should not be FAILED val stageIds = jobInfo.stageIds() stageIds.size should be(2) val firstStageInfo = eventually(timeout(10 seconds)) { - sc.getStageInfo(stageIds(0)).get + sc.statusTracker.getStageInfo(stageIds(0)).get } firstStageInfo.stageId() should be(stageIds(0)) firstStageInfo.currentAttemptId() should be(0) firstStageInfo.numTasks() should be(2) eventually(timeout(10 seconds)) { - val updatedFirstStageInfo = sc.getStageInfo(stageIds(0)).get + val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds(0)).get updatedFirstStageInfo.numCompletedTasks() should be(2) updatedFirstStageInfo.numActiveTasks() should be(0) updatedFirstStageInfo.numFailedTasks() should be(0) @@ -58,21 +59,31 @@ class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { } test("getJobIdsForGroup()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + // Passing `null` should return jobs that were not run in a job group: + val defaultJobGroupFuture = sc.parallelize(1 to 1000).countAsync() + val defaultJobGroupJobId = eventually(timeout(10 
seconds)) { + defaultJobGroupFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup(null).toSet should be (Set(defaultJobGroupJobId)) + } + // Test jobs submitted in job groups: sc.setJobGroup("my-job-group", "description") - sc.getJobIdsForGroup("my-job-group") should be (Seq.empty) + sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq.empty) val firstJobFuture = sc.parallelize(1 to 1000).countAsync() val firstJobId = eventually(timeout(10 seconds)) { firstJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId)) } val secondJobFuture = sc.parallelize(1 to 1000).countAsync() val secondJobId = eventually(timeout(10 seconds)) { secondJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) } } } \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 94a2bdd74e744..d2dae34be7bfb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -23,17 +23,26 @@ import org.scalatest.Matchers class ClientSuite extends FunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) - ClientArguments.isValidJarUrl("file://some/path/to/a/jarFile.jar") should be (true) + + // file scheme with authority and path is valid. + ClientArguments.isValidJarUrl("file://somehost/path/to/a/jarFile.jar") should be (true) + + // file scheme without path is not valid. + // In this case, jarFile.jar is recognized as authority. + ClientArguments.isValidJarUrl("file://jarFile.jar") should be (false) + + // file scheme without authority but with triple slash is valid. + ClientArguments.isValidJarUrl("file:///some/path/to/a/jarFile.jar") should be (true) ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo.jar") should be (true) ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo") should be (false) ClientArguments.isValidJarUrl("/missing/a/protocol/jarfile.jar") should be (false) ClientArguments.isValidJarUrl("not-even-a-path.jar") should be (false) - // No authority + // This URI doesn't have authority and path. ClientArguments.isValidJarUrl("hdfs:someHost:1234/jarfile.jar") should be (false) - // Invalid syntax + // Invalid syntax. 
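The file URL cases spelled out in the ClientSuite hunk above come down to how java.net.URI splits a "file:" URL into authority and path; the assertions above imply a valid jar URL must parse with a path that ends in ".jar". The sketch below just prints the parse results for the same inputs; it is illustrative and not part of the patch.

    import java.net.URI

    object JarUrlParsingSketch {
      def main(args: Array[String]): Unit = {
        Seq(
          "file://somehost/path/to/a/jarFile.jar",  // authority = somehost, path = /path/to/a/jarFile.jar
          "file://jarFile.jar",                     // authority = jarFile.jar, path is empty, so it is rejected
          "file:///some/path/to/a/jarFile.jar",     // no authority, path = /some/path/to/a/jarFile.jar
          "hdfs:someHost:1234/jarfile.jar"          // opaque URI: neither authority nor path, so it is rejected
        ).foreach { s =>
          val uri = new URI(s)
          println(s"$s -> authority=${uri.getAuthority}, path=${uri.getPath}")
        }
      }
    }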
ClientArguments.isValidJarUrl("hdfs:") should be (false) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d8cd0ff2c9026..eb7bd7ab3986e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,7 +21,7 @@ import java.io._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkException, TestUtils} +import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite @@ -451,24 +451,25 @@ class SparkSubmitSuite extends FunSuite with Matchers { } } -object JarCreationTest { +object JarCreationTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => - var foundClasses = false + var exception: String = null try { Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) - foundClasses = true } catch { - case _: Throwable => // catch all + case t: Throwable => + exception = t + "\n" + t.getStackTraceString + exception = exception.replaceAll("\n", "\n\t") } - Seq(foundClasses).iterator + Option(exception).toSeq.iterator }.collect() - if (result.contains(false)) { - throw new Exception("Could not load user defined classes inside of executors") + if (result.nonEmpty) { + throw new Exception("Could not load user class from jar:\n" + result(0)) } } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 12d1c7b2faba6..98b0a16ce88ba 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.io.Text import org.apache.spark.SparkContext import org.apache.spark.util.Utils +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} /** * Tests the correctness of @@ -38,20 +39,32 @@ import org.apache.spark.util.Utils */ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ + private var factory: CompressionCodecFactory = _ override def beforeAll() { sc = new SparkContext("local", "test") // Set the block size of local file system to test whether files are split right or not. 
sc.hadoopConfiguration.setLong("fs.local.block.size", 32) + sc.hadoopConfiguration.set("io.compression.codecs", + "org.apache.hadoop.io.compress.GzipCodec,org.apache.hadoop.io.compress.DefaultCodec") + factory = new CompressionCodecFactory(sc.hadoopConfiguration) } override def afterAll() { sc.stop() } - private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte]) = { - val out = new DataOutputStream(new FileOutputStream(s"${inputDir.toString}/$fileName")) + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val out = if (compress) { + val codec = new GzipCodec + val path = s"${inputDir.toString}/$fileName${codec.getDefaultExtension}" + codec.createOutputStream(new DataOutputStream(new FileOutputStream(path))) + } else { + val path = s"${inputDir.toString}/$fileName" + new DataOutputStream(new FileOutputStream(path)) + } out.write(contents, 0, contents.length) out.close() } @@ -68,7 +81,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { println(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => - createNativeFile(dir, filename, contents) + createNativeFile(dir, filename, contents, false) } val res = sc.wholeTextFiles(dir.toString, 3).collect() @@ -86,6 +99,31 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { Utils.deleteRecursively(dir) } + + test("Correctness of WholeTextFileRecordReader with GzipCodec.") { + val dir = Utils.createTempDir() + println(s"Local disk address is ${dir.toString}.") + + WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, true) + } + + val res = sc.wholeTextFiles(dir.toString, 3).collect() + + assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, + "Number of files read out does not fit with the actual value.") + + for ((filename, contents) <- res) { + val shortName = filename.split('/').last.split('.')(0) + + assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), + s"Missing file name $filename.") + assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, + s"file $filename contents can not match.") + } + + Utils.deleteRecursively(dir) + } } /** diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 530f5d6db5a29..94bfa67451892 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -104,11 +104,11 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) - val exec0 = new NettyBlockTransferService(conf0, securityManager0) + val exec0 = new NettyBlockTransferService(conf0, securityManager0, numCores = 1) exec0.init(blockManager) val securityManager1 = new SecurityManager(conf1) - val exec1 = new NettyBlockTransferService(conf1, securityManager1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1, numCores = 1) exec1.init(blockManager) val result = fetchBlock(exec0, exec1, "1", blockId) match { diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala 
b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 76e317d754ba3..6138d0bbd57f6 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -65,10 +65,11 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { after { try { - DriverManager.getConnection("jdbc:derby:;shutdown=true") + DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true") } catch { - case se: SQLException if se.getSQLState == "XJ015" => - // normal shutdown + case se: SQLException if se.getSQLState == "08006" => + // Normal single database shutdown + // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6d2e696dc2fc4..e079ca3b1e896 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -739,6 +739,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("zipWithIndex chained with other RDDs (SPARK-4433)") { + val count = sc.parallelize(0 until 10, 2).zipWithIndex().repartition(4).count() + assert(count === 10) + } + test("zipWithUniqueId") { val n = 10 val data = sc.parallelize(0 until n, 3) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 819f95634bcdc..bdd721dc7eaf7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -207,7 +207,18 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null)) + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null)) + } + } + } + + private def completeWithAccumulator(accumId: Long, taskSet: TaskSet, + results: Seq[(TaskEndReason, Any)]) { + assert(taskSet.tasks.size >= results.size) + for ((result, i) <- results.zipWithIndex) { + if (i < taskSet.tasks.size) { + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, + Map[Long, Any]((accumId, 1)), null, null)) } } } @@ -493,17 +504,16 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runEvent(ExecutorLost("exec-hostA")) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) - val noAccum = Map[Long, Any]() val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) // should work because it's a new epoch 
taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -728,6 +738,18 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assert(scheduler.sc.dagScheduler === null) } + test("accumulator not calculated for resubmitted result stage") { + //just for register + val accum = new Accumulator[Int](0, SparkContext.IntAccumulatorParam) + val finalRdd = new MyRDD(sc, 1, Nil) + submit(finalRdd, Array(0)) + completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) + completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(Accumulators.originals(accum.id).value === 1) + assertDataStructuresEmpty + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 1809b5396d53e..472191551a01f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -579,13 +579,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // single 10M result val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} - assert(thrown.getMessage().contains("bigger than maxResultSize")) + assert(thrown.getMessage().contains("bigger than spark.driver.maxResultSize")) // multiple 1M results val thrown2 = intercept[SparkException] { sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect() } - assert(thrown2.getMessage().contains("bigger than maxResultSize")) + assert(thrown2.getMessage().contains("bigger than spark.driver.maxResultSize")) } test("speculative and noPref task should be scheduled after node-local") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..e60e70afd3218 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.mesos + +import org.scalatest.FunSuite +import org.apache.spark.{scheduler, SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.scheduler.{TaskDescription, WorkerOffer, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} +import org.apache.mesos.SchedulerDriver +import org.apache.mesos.Protos._ +import org.scalatest.mock.EasyMockSugar +import org.apache.mesos.Protos.Value.Scalar +import org.easymock.{Capture, EasyMock} +import java.nio.ByteBuffer +import java.util.Collections +import java.util +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { + + test("mesos resource offers result in launching tasks") { + def createOffer(id: Int, mem: Int, cpu: Int) = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() + } + + val driver = EasyMock.createMock(classOf[SchedulerDriver]) + val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + + val sc = EasyMock.createMock(classOf[SparkContext]) + EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() + EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() + EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() + EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() + EasyMock.replay(sc) + + val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer(1, minMem, minCpu)) + mesosOffers.add(createOffer(2, minMem - 1, minCpu)) + mesosOffers.add(createOffer(3, minMem, minCpu)) + + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2) + expectedWorkerOffers.append(new WorkerOffer( + mesosOffers.get(0).getSlaveId.getValue, + mesosOffers.get(0).getHostname, + 2 + )) + expectedWorkerOffers.append(new WorkerOffer( + mesosOffers.get(2).getSlaveId.getValue, + mesosOffers.get(2).getHostname, + 2 + )) + val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) + EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() + EasyMock.replay(taskScheduler) + + val capture = new Capture[util.Collection[TaskInfo]] + EasyMock.expect( + driver.launchTasks( + EasyMock.eq(Collections.singleton(mesosOffers.get(0).getId)), + EasyMock.capture(capture), + EasyMock.anyObject(classOf[Filters]) + ) + ).andReturn(Status.valueOf(1)).once + EasyMock.expect(driver.declineOffer(mesosOffers.get(1).getId)).andReturn(Status.valueOf(1)).times(1) + EasyMock.expect(driver.declineOffer(mesosOffers.get(2).getId)).andReturn(Status.valueOf(1)).times(1) + EasyMock.replay(driver) + + backend.resourceOffers(driver, mesosOffers) + + EasyMock.verify(driver) + assert(capture.getValue.size() == 1) + val taskInfo = 
capture.getValue.iterator().next() + assert(taskInfo.getName.equals("n1")) + val cpus = taskInfo.getResourcesList.get(0) + assert(cpus.getName.equals("cpus")) + assert(cpus.getScalar.getValue.equals(2.0)) + assert(taskInfo.getSlaveId.getValue.equals("s1")) + + // Unwanted resources offered on an existing node. Make sure they are declined + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer(1, minMem, minCpu)) + EasyMock.reset(taskScheduler) + EasyMock.reset(driver) + EasyMock.expect(taskScheduler.resourceOffers(EasyMock.anyObject(classOf[Seq[WorkerOffer]]))).andReturn(Seq(Seq())) + EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() + EasyMock.replay(taskScheduler) + EasyMock.expect(driver.declineOffer(mesosOffers2.get(0).getId)).andReturn(Status.valueOf(1)).times(1) + EasyMock.replay(driver) + + backend.resourceOffers(driver, mesosOffers2) + EasyMock.verify(driver) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f63e772bf1e59..c2903c8597997 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -62,7 +62,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr) + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store store @@ -263,7 +263,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, - 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr) + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9529502bc8e10..5554efbcbadf8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -74,7 +74,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr) + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } @@ -795,7 +795,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Use Java serializer so we can create an unserializable error.
val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr) + new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, + 0) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index bacf6a16fc233..d2857b8b55664 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.ui -import org.apache.spark.api.java.StorageLevels -import org.apache.spark.{SparkException, SparkConf, SparkContext} -import org.openqa.selenium.WebDriver +import scala.collection.JavaConversions._ + +import org.openqa.selenium.{By, WebDriver} import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ +import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.LocalSparkContext._ +import org.apache.spark.api.java.StorageLevels +import org.apache.spark.shuffle.FetchFailedException /** * Selenium tests for the Spark Web UI. These tests are not run by default @@ -89,7 +93,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { sc.parallelize(1 to 10).map { x => throw new Exception()}.collect() } eventually(timeout(5 seconds), interval(50 milliseconds)) { - go to sc.ui.get.appUIAddress + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") find(id("active")).get.text should be("Active Stages (0)") find(id("failed")).get.text should be("Failed Stages (1)") } @@ -101,7 +105,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { sc.parallelize(1 to 10).map { x => unserializableObject}.collect() } eventually(timeout(5 seconds), interval(50 milliseconds)) { - go to sc.ui.get.appUIAddress + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") find(id("active")).get.text should be("Active Stages (0)") // The failure occurs before the stage becomes active, hence we should still show only one // failed stage, not two: @@ -109,4 +113,191 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } } } + + test("spark.ui.killEnabled should properly control kill button display") { + def getSparkContext(killEnabled: Boolean): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + .set("spark.ui.killEnabled", killEnabled.toString) + new SparkContext(conf) + } + + def hasKillLink = find(className("kill-link")).isDefined + def runSlowJob(sc: SparkContext) { + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + } + + withSpark(getSparkContext(killEnabled = true)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") + assert(hasKillLink) + } + } + + withSpark(getSparkContext(killEnabled = false)) { sc => + runSlowJob(sc) + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") + assert(!hasKillLink) + } + } + } + + test("jobs page should not display job group 
name unless some job was submitted in a job group") { + withSpark(newSparkContext()) { sc => + // If no job has been run in a job group, then "(Job Group)" should not appear in the header + sc.parallelize(Seq(1, 2, 3)).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + val tableHeaders = findAll(cssSelector("th")).map(_.text).toSeq + tableHeaders should not contain "Job Id (Job Group)" + } + // Once at least one job has been run in a job group, then we should display the group name: + sc.setJobGroup("my-job-group", "my-job-group-description") + sc.parallelize(Seq(1, 2, 3)).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + val tableHeaders = findAll(cssSelector("th")).map(_.text).toSeq + tableHeaders should contain ("Job Id (Job Group)") + } + } + } + + test("job progress bars should handle stage / task failures") { + withSpark(newSparkContext()) { sc => + val data = sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity) + val shuffleHandle = + data.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + // Simulate fetch failures: + val mappedData = data.map { x => + val taskContext = TaskContext.get + if (taskContext.attemptId() == 1) { // Cause this stage to fail on its first attempt. + val env = SparkEnv.get + val bmAddress = env.blockManager.blockManagerId + val shuffleId = shuffleHandle.shuffleId + val mapId = 0 + val reduceId = taskContext.partitionId() + val message = "Simulated fetch failure" + throw new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, message) + } else { + x + } + } + mappedData.count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + find(cssSelector(".stage-progress-cell")).get.text should be ("2/2 (1 failed)") + // Ideally, the following test would pass, but currently we overcount completed tasks + // if task recomputations occur: + // find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") + // Instead, we guarantee that the total number of tasks is always correct, while the number + // of completed tasks may be higher: + find(cssSelector(".progress-cell .progress")).get.text should be ("3/2 (1 failed)") + } + } + } + + test("job details page should display useful information for stages that haven't started") { + withSpark(newSparkContext()) { sc => + // Create a multi-stage job with a long delay in the first stage: + val rdd = sc.parallelize(Seq(1, 2, 3)).map { x => + // This long sleep call won't slow down the tests because we don't actually need to wait + // for the job to finish. + Thread.sleep(20000) + }.groupBy(identity).map(identity).groupBy(identity).map(identity) + // Start the job: + rdd.countAsync() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/?id=0") + find(id("active")).get.text should be ("Active Stages (1)") + find(id("pending")).get.text should be ("Pending Stages (2)") + // Essentially, we want to check that none of the stage rows show + // "No data available for this stage". Checking for the absence of that string is brittle + // because someone could change the error message and cause this test to pass by accident. + // Instead, it's safer to check that each row contains a link to a stage details page. 
+ findAll(cssSelector("tbody tr")).foreach { row => + val link = row.underlying.findElement(By.xpath(".//a")) + link.getAttribute("href") should include ("stage") + } + } + } + } + + test("job progress bars / cells reflect skipped stages / tasks") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = sc.parallelize(1 to 8, 8) + .map(x => x).groupBy((x: Int) => x, numPartitions = 8) + .flatMap(x => x._2).groupBy((x: Int) => x, numPartitions = 8) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + // The completed jobs table should have two rows. The first row will be the most recent job: + val firstRow = find(cssSelector("tbody tr")).get.underlying + val firstRowColumns = firstRow.findElements(By.tagName("td")) + firstRowColumns(0).getText should be ("1") + firstRowColumns(4).getText should be ("1/1 (2 skipped)") + firstRowColumns(5).getText should be ("8/8 (16 skipped)") + // The second row is the first run of the job, where nothing was skipped: + val secondRow = findAll(cssSelector("tbody tr")).toSeq(1).underlying + val secondRowColumns = secondRow.findElements(By.tagName("td")) + secondRowColumns(0).getText should be ("0") + secondRowColumns(4).getText should be ("3/3") + secondRowColumns(5).getText should be ("24/24") + } + } + } + + test("stages that aren't run appear as 'skipped stages' after a job finishes") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/?id=1") + find(id("pending")) should be (None) + find(id("active")) should be (None) + find(id("failed")) should be (None) + find(id("completed")).get.text should be ("Completed Stages (1)") + find(id("skipped")).get.text should be ("Skipped Stages (2)") + // Essentially, we want to check that none of the stage rows show + // "No data available for this stage". Checking for the absence of that string is brittle + // because someone could change the error message and cause this test to pass by accident. + // Instead, it's safer to check that each row contains a link to a stage details page. 
+ findAll(cssSelector("tbody tr")).foreach { row => + val link = row.underlying.findElement(By.xpath(".//a")) + link.getAttribute("href") should include ("stage") + } + } + } + } + + test("jobs with stages that are skipped should show correct link descriptions on all jobs page") { + withSpark(newSparkContext()) { sc => + // Create an RDD that involves multiple stages: + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + // Run it twice; this will cause the second job to have two "phantom" stages that were + // mentioned in its job start event but which were never actually executed: + rdd.count() + rdd.count() + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs") + findAll(cssSelector("tbody tr a")).foreach { link => + link.text.toLowerCase should include ("count") + link.text.toLowerCase should not include "unknown" + } + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 7c102cc7f4049..12af60caf7d54 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -28,32 +28,106 @@ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { - test("test LRU eviction of stages") { - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - val listener = new JobProgressListener(conf) - def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") - SparkListenerStageSubmitted(stageInfo) + private def createStageStartEvent(stageId: Int) = { + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") + SparkListenerStageSubmitted(stageInfo) + } + + private def createStageEndEvent(stageId: Int, failed: Boolean = false) = { + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") + if (failed) { + stageInfo.failureReason = Some("Failed!") } + SparkListenerStageCompleted(stageInfo) + } - def createStageEndEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") - SparkListenerStageCompleted(stageInfo) + private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { + val stageInfos = stageIds.map { stageId => + new StageInfo(stageId, 0, stageId.toString, 0, null, "") } + SparkListenerJobStart(jobId, stageInfos) + } + + private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { + val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded + SparkListenerJobEnd(jobId, result) + } + + private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { + val stagesThatWontBeRun = jobId * 200 to jobId * 200 + 10 + val stageIds = jobId * 100 to jobId * 100 + 50 + listener.onJobStart(createJobStartEvent(jobId, stageIds ++ stagesThatWontBeRun)) + for (stageId <- stageIds) { + listener.onStageSubmitted(createStageStartEvent(stageId)) + listener.onStageCompleted(createStageEndEvent(stageId, failed = stageId % 2 == 0)) + } + listener.onJobEnd(createJobEndEvent(jobId, shouldFail)) + } + + private def assertActiveJobsStateIsEmpty(listener: JobProgressListener) { + listener.getSizesOfActiveStateTrackingCollections.foreach { case (fieldName, size) => + assert(size === 0, 
s"$fieldName was not empty") + } + } + + test("test LRU eviction of stages") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + val listener = new JobProgressListener(conf) for (i <- 1 to 50) { listener.onStageSubmitted(createStageStartEvent(i)) listener.onStageCompleted(createStageEndEvent(i)) } + assertActiveJobsStateIsEmpty(listener) listener.completedStages.size should be (5) - listener.completedStages.count(_.stageId == 50) should be (1) - listener.completedStages.count(_.stageId == 49) should be (1) - listener.completedStages.count(_.stageId == 48) should be (1) - listener.completedStages.count(_.stageId == 47) should be (1) - listener.completedStages.count(_.stageId == 46) should be (1) + listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) + } + + test("test LRU eviction of jobs") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + conf.set("spark.ui.retainedJobs", 5.toString) + val listener = new JobProgressListener(conf) + + // Run a bunch of jobs to get the listener into a state where we've exceeded both the + // job and stage retention limits: + for (jobId <- 1 to 10) { + runJob(listener, jobId, shouldFail = false) + } + for (jobId <- 200 to 210) { + runJob(listener, jobId, shouldFail = true) + } + assertActiveJobsStateIsEmpty(listener) + // Snapshot the sizes of various soft- and hard-size-limited collections: + val softLimitSizes = listener.getSizesOfSoftSizeLimitedCollections + val hardLimitSizes = listener.getSizesOfHardSizeLimitedCollections + // Run some more jobs: + for (jobId <- 11 to 50) { + runJob(listener, jobId, shouldFail = false) + // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: + listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) + listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) + } + + listener.completedJobs.size should be (5) + listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) + + for (jobId <- 51 to 100) { + runJob(listener, jobId, shouldFail = true) + // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: + listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) + listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) + } + assertActiveJobsStateIsEmpty(listener) + + // Completed and failed jobs each their own size limits, so this should still be the same: + listener.completedJobs.size should be (5) + listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) + listener.failedJobs.size should be (5) + listener.failedJobs.map(_.jobId).toSet should be (Set(100, 99, 98, 97, 96)) } test("test executor id to summary") { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 50f42054b9296..593d6dd8c3794 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.util.Properties +import org.apache.spark.shuffle.MetadataFetchFailedException + import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -47,7 +49,12 @@ class JsonProtocolSuite extends FunSuite { val taskEndWithOutput = SparkListenerTaskEnd(1, 0, "ResultTask", Success, makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = 
true, hasOutput = true)) - val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties) + val jobStart = { + val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => + makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) + SparkListenerJobStart(10, stageInfos, properties) + } val jobEnd = SparkListenerJobEnd(20, JobSucceeded) val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]]( "JVM Information" -> Seq(("GC speed", "9999 objects/s"), ("Java home", "Land of coffee")), @@ -111,10 +118,13 @@ class JsonProtocolSuite extends FunSuite { // TaskEndReason val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, "Some exception") + val fetchMetadataFailed = new MetadataFetchFailedException(17, + 19, "metadata Fetch failed exception").toTaskEndReason val exceptionFailure = new ExceptionFailure(exception, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) + testTaskEndReason(fetchMetadataFailed) testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) @@ -224,6 +234,19 @@ class JsonProtocolSuite extends FunSuite { assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } + test("SparkListenerJobStart backward compatibility") { + // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. + val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val dummyStageInfos = + stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) + val jobStart = SparkListenerJobStart(10, stageInfos, properties) + val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"}) + val expectedJobStart = + SparkListenerJobStart(10, dummyStageInfos, properties) + assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -306,7 +329,7 @@ class JsonProtocolSuite extends FunSuite { case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) => assert(e1.jobId === e2.jobId) assert(e1.properties === e2.properties) - assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2)) + assert(e1.stageIds === e2.stageIds) case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) => assert(e1.jobId === e2.jobId) assertEquals(e1.jobResult, e2.jobResult) @@ -413,9 +436,13 @@ class JsonProtocolSuite extends FunSuite { } private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) { - assert(bm1.executorId === bm2.executorId) - assert(bm1.host === bm2.host) - assert(bm1.port === bm2.port) + if (bm1 == null || bm2 == null) { + assert(bm1 === bm2) + } else { + assert(bm1.executorId === bm2.executorId) + assert(bm1.host === bm2.host) + assert(bm1.port === bm2.port) + } } private def assertEquals(result1: JobResult, result2: JobResult) { @@ -1051,6 +1078,260 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobStart", | "Job ID": 10, + | "Stage Infos": [ + | { + | "Stage ID": 1, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 200, + | "RDD Info": [ + | { + | "RDD ID": 1, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 200, + | "Number of Cached 
Partitions": 300, + | "Memory Size": 400, + | "Tachyon Size": 0, + | "Disk Size": 500 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 2, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 400, + | "RDD Info": [ + | { + | "RDD ID": 2, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 400, + | "Number of Cached Partitions": 600, + | "Memory Size": 800, + | "Tachyon Size": 0, + | "Disk Size": 1000 + | }, + | { + | "RDD ID": 3, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 401, + | "Number of Cached Partitions": 601, + | "Memory Size": 801, + | "Tachyon Size": 0, + | "Disk Size": 1001 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 3, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 600, + | "RDD Info": [ + | { + | "RDD ID": 3, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 600, + | "Number of Cached Partitions": 900, + | "Memory Size": 1200, + | "Tachyon Size": 0, + | "Disk Size": 1500 + | }, + | { + | "RDD ID": 4, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 601, + | "Number of Cached Partitions": 901, + | "Memory Size": 1201, + | "Tachyon Size": 0, + | "Disk Size": 1501 + | }, + | { + | "RDD ID": 5, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 602, + | "Number of Cached Partitions": 902, + | "Memory Size": 1202, + | "Tachyon Size": 0, + | "Disk Size": 1502 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | { + | "Stage ID": 4, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 800, + | "RDD Info": [ + | { + | "RDD ID": 4, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 800, + | "Number of Cached Partitions": 1200, + | "Memory Size": 1600, + | "Tachyon Size": 0, + | "Disk Size": 2000 + | }, + | { + | "RDD ID": 5, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of 
Partitions": 801, + | "Number of Cached Partitions": 1201, + | "Memory Size": 1601, + | "Tachyon Size": 0, + | "Disk Size": 2001 + | }, + | { + | "RDD ID": 6, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 802, + | "Number of Cached Partitions": 1202, + | "Memory Size": 1602, + | "Tachyon Size": 0, + | "Disk Size": 2002 + | }, + | { + | "RDD ID": 7, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 803, + | "Number of Cached Partitions": 1203, + | "Memory Size": 1603, + | "Tachyon Size": 0, + | "Disk Size": 2003 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": " Accumulable 2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": " Accumulable 1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | } + | ], | "Stage IDs": [ | 1, | 2, diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index f9d1af88f3a13..0ea2d13a83505 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -118,7 +118,7 @@ class SizeEstimatorSuite // TODO: If we sample 100 elements, this should always be 4176 ? val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") - assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4100") + assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4200") } test("32-bit arch") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index f26e40fbd4b36..3cb42d416de4f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -127,6 +127,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe test("empty partitions with spilling") { val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -152,6 +153,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe test("empty partitions with spilling, bypass merge-sort") { val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) diff --git a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala new file mode 100644 index 0000000000000..4918e2d92beb4 --- /dev/null +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.sparktest
+
+/**
+ * A test suite to make sure all `implicit` functions work correctly.
+ * Please don't `import org.apache.spark.SparkContext._` in this class.
+ *
+ * As `implicit` is a compiler feature, we don't need to run this class.
+ * All we need to do is keep the compiler happy.
+ */
+class ImplicitSuite {
+
+  // We only want to test if `implicit` works well with the compiler, so we don't need a real
+  // SparkContext.
+  def mockSparkContext[T]: org.apache.spark.SparkContext = null
+
+  // We only want to test if `implicit` works well with the compiler, so we don't need a real RDD.
+  def mockRDD[T]: org.apache.spark.rdd.RDD[T] = null
+
+  def testRddToPairRDDFunctions(): Unit = {
+    val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD
+    rdd.groupByKey()
+  }
+
+  def testRddToAsyncRDDActions(): Unit = {
+    val rdd: org.apache.spark.rdd.RDD[Int] = mockRDD
+    rdd.countAsync()
+  }
+
+  def testRddToSequenceFileRDDFunctions(): Unit = {
+    // TODO: eliminating `import intToIntWritable` requires refactoring SequenceFileRDDFunctions.
+    // That will be a breaking change.
+ import org.apache.spark.SparkContext.intToIntWritable + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + + def testRddToOrderedRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD + rdd.sortByKey() + } + + def testDoubleRDDToDoubleRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Double] = mockRDD + rdd.stats() + } + + def testNumericRDDToDoubleRDDFunctions(): Unit = { + val rdd: org.apache.spark.rdd.RDD[Int] = mockRDD + rdd.stats() + } + + def testDoubleAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123.4) + } + + def testIntAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123) + } + + def testLongAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123L) + } + + def testFloatAccumulatorParam(): Unit = { + val sc = mockSparkContext + sc.accumulator(123F) + } + + def testIntWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Int, Int]("/a/test/path") + } + + def testLongWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Long, Long]("/a/test/path") + } + + def testDoubleWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Double, Double]("/a/test/path") + } + + def testFloatWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Float, Float]("/a/test/path") + } + + def testBooleanWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Boolean, Boolean]("/a/test/path") + } + + def testBytesWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[Array[Byte], Array[Byte]]("/a/test/path") + } + + def testStringWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[String, String]("/a/test/path") + } + + def testWritableWritableConverter(): Unit = { + val sc = mockSparkContext + sc.sequenceFile[org.apache.hadoop.io.Text, org.apache.hadoop.io.Text]("/a/test/path") + } +} diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 16ea1a71290dc..0b7069f6e116a 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -30,71 +30,84 @@ import time import urllib2 -# Fill in release details here: -RELEASE_URL = "http://people.apache.org/~pwendell/spark-1.0.0-rc1/" -RELEASE_KEY = "9E4FE3AF" -RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/" -RELEASE_VERSION = "1.0.0" +# Note: The following variables must be set before use! +RELEASE_URL = "http://people.apache.org/~andrewor14/spark-1.1.1-rc1/" +RELEASE_KEY = "XXXXXXXX" # Your 8-digit hex +RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1033" +RELEASE_VERSION = "1.1.1" SCALA_VERSION = "2.10.4" SCALA_BINARY_VERSION = "2.10" -# +# Do not set these LOG_FILE_NAME = "spark_audit_%s" % time.strftime("%h_%m_%Y_%I_%M_%S") LOG_FILE = open(LOG_FILE_NAME, 'w') WORK_DIR = "/tmp/audit_%s" % int(time.time()) MAVEN_CMD = "mvn" GPG_CMD = "gpg" +SBT_CMD = "sbt -Dsbt.log.noformat=true" -print "Starting tests, log output in %s. Test results printed below:" % LOG_FILE_NAME - -# Track failures +# Track failures to print them at the end failures = [] +# Log a message. Use sparingly because this flushes every write. 
+def log(msg): + LOG_FILE.write(msg + "\n") + LOG_FILE.flush() +def log_and_print(msg): + print msg + log(msg) + +# Prompt the user to delete the scratch directory used def clean_work_files(): - print "OK to delete scratch directory '%s'? (y/N): " % WORK_DIR - response = raw_input() + response = raw_input("OK to delete scratch directory '%s'? (y/N) " % WORK_DIR) if response == "y": shutil.rmtree(WORK_DIR) - print "Should I delete the log output file '%s'? (y/N): " % LOG_FILE_NAME - response = raw_input() - if response == "y": - os.unlink(LOG_FILE_NAME) - +# Run the given command and log its output to the log file def run_cmd(cmd, exit_on_failure=True): - print >> LOG_FILE, "Running command: %s" % cmd + log("Running command: %s" % cmd) ret = subprocess.call(cmd, shell=True, stdout=LOG_FILE, stderr=LOG_FILE) if ret != 0 and exit_on_failure: - print "Command failed: %s" % cmd + log_and_print("Command failed: %s" % cmd) clean_work_files() sys.exit(-1) return ret - def run_cmd_with_output(cmd): - print >> sys.stderr, "Running command: %s" % cmd + log_and_print("Running command: %s" % cmd) return subprocess.check_output(cmd, shell=True, stderr=LOG_FILE) +# Test if the given condition is successful +# If so, print the pass message; otherwise print the failure message +def test(cond, msg): + return passed(msg) if cond else failed(msg) -def test(bool, str): - if bool: - return passed(str) - failed(str) - - -def passed(str): - print "[PASSED] %s" % str - - -def failed(str): - failures.append(str) - print "[**FAILED**] %s" % str +def passed(msg): + log_and_print("[PASSED] %s" % msg) +def failed(msg): + failures.append(msg) + log_and_print("[**FAILED**] %s" % msg) def get_url(url): return urllib2.urlopen(url).read() +# If the path exists, prompt the user to delete it +# If the resource is not deleted, abort +def ensure_path_not_present(path): + full_path = os.path.expanduser(path) + if os.path.exists(full_path): + print "Found %s locally." % full_path + response = raw_input("This can interfere with testing published artifacts. OK to delete? (y/N) ") + if response == "y": + shutil.rmtree(full_path) + else: + print "Abort." + sys.exit(-1) + +log_and_print("|-------- Starting Spark audit tests for release %s --------|" % RELEASE_VERSION) +log_and_print("Log output can be found in %s" % LOG_FILE_NAME) original_dir = os.getcwd() @@ -114,37 +127,36 @@ def get_url(url): cache_ivy_spark = "~/.ivy2/cache/org.apache.spark" local_maven_kafka = "~/.m2/repository/org/apache/kafka" local_maven_kafka = "~/.m2/repository/org/apache/spark" - - -def ensure_path_not_present(x): - if os.path.exists(os.path.expanduser(x)): - print "Please remove %s, it can interfere with testing published artifacts." 
% x - sys.exit(-1) - map(ensure_path_not_present, [local_ivy_spark, cache_ivy_spark, local_maven_kafka]) # SBT build tests +log_and_print("==== Building SBT modules ====") os.chdir("blank_sbt_build") os.environ["SPARK_VERSION"] = RELEASE_VERSION os.environ["SCALA_VERSION"] = SCALA_VERSION os.environ["SPARK_RELEASE_REPOSITORY"] = RELEASE_REPOSITORY os.environ["SPARK_AUDIT_MASTER"] = "local" for module in modules: + log("==== Building module %s in SBT ====" % module) os.environ["SPARK_MODULE"] = module - ret = run_cmd("sbt clean update", exit_on_failure=False) - test(ret == 0, "sbt build against '%s' module" % module) + ret = run_cmd("%s clean update" % SBT_CMD, exit_on_failure=False) + test(ret == 0, "SBT build against '%s' module" % module) os.chdir(original_dir) # SBT application tests +log_and_print("==== Building SBT applications ====") for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive", "sbt_app_kinesis"]: + log("==== Building application %s in SBT ====" % app) os.chdir(app) - ret = run_cmd("sbt clean run", exit_on_failure=False) - test(ret == 0, "sbt application (%s)" % app) + ret = run_cmd("%s clean run" % SBT_CMD, exit_on_failure=False) + test(ret == 0, "SBT application (%s)" % app) os.chdir(original_dir) # Maven build tests os.chdir("blank_maven_build") +log_and_print("==== Building Maven modules ====") for module in modules: + log("==== Building module %s in maven ====" % module) cmd = ('%s --update-snapshots -Dspark.release.repository="%s" -Dspark.version="%s" ' '-Dspark.module="%s" clean compile' % (MAVEN_CMD, RELEASE_REPOSITORY, RELEASE_VERSION, module)) @@ -152,6 +164,8 @@ def ensure_path_not_present(x): test(ret == 0, "maven build against '%s' module" % module) os.chdir(original_dir) +# Maven application tests +log_and_print("==== Building Maven applications ====") os.chdir("maven_app_core") mvn_exec_cmd = ('%s --update-snapshots -Dspark.release.repository="%s" -Dspark.version="%s" ' '-Dscala.binary.version="%s" clean compile ' @@ -172,15 +186,14 @@ def ensure_path_not_present(x): artifact_regex = r = re.compile("") artifacts = r.findall(index_page) +# Verify artifact integrity for artifact in artifacts: - print "==== Verifying download integrity for artifact: %s ====" % artifact + log_and_print("==== Verifying download integrity for artifact: %s ====" % artifact) artifact_url = "%s/%s" % (RELEASE_URL, artifact) - run_cmd("wget %s" % artifact_url) - key_file = "%s.asc" % artifact + run_cmd("wget %s" % artifact_url) run_cmd("wget %s/%s" % (RELEASE_URL, key_file)) - run_cmd("wget %s%s" % (artifact_url, ".sha")) # Verify signature @@ -208,31 +221,17 @@ def ensure_path_not_present(x): os.chdir(WORK_DIR) -for artifact in artifacts: - print "==== Verifying build and tests for artifact: %s ====" % artifact - os.chdir(os.path.join(WORK_DIR, dir_name)) - - os.environ["MAVEN_OPTS"] = "-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - # Verify build - print "==> Running build" - run_cmd("sbt assembly") - passed("sbt build successful") - run_cmd("%s package -DskipTests" % MAVEN_CMD) - passed("Maven build successful") - - # Verify tests - print "==> Performing unit tests" - run_cmd("%s test" % MAVEN_CMD) - passed("Tests successful") - os.chdir(WORK_DIR) - -clean_work_files() - +# Report result +log_and_print("\n") if len(failures) == 0: - print "ALL TESTS PASSED" + log_and_print("*** ALL TESTS PASSED ***") else: - print "SOME TESTS DID NOT PASS" + log_and_print("XXXXX SOME TESTS DID NOT PASS XXXXX") for f in failures: - print f - + 
log_and_print(" %s" % f) os.chdir(original_dir) + +# Clean up +clean_work_files() + +log_and_print("|-------- Spark release audit complete --------|") diff --git a/dev/audit-release/blank_sbt_build/build.sbt b/dev/audit-release/blank_sbt_build/build.sbt index 696c7f651837c..62815542e5bd9 100644 --- a/dev/audit-release/blank_sbt_build/build.sbt +++ b/dev/audit-release/blank_sbt_build/build.sbt @@ -19,10 +19,12 @@ name := "Spark Release Auditor" version := "1.0" -scalaVersion := "2.9.3" +scalaVersion := System.getenv.get("SCALA_VERSION") libraryDependencies += "org.apache.spark" % System.getenv.get("SPARK_MODULE") % System.getenv.get("SPARK_VERSION") resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Eclipse Paho Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/", + "Maven Repository" at "http://repo1.maven.org/maven2/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_hive/build.sbt b/dev/audit-release/sbt_app_hive/build.sbt index a0d4f25da5842..c8824f2b15e55 100644 --- a/dev/audit-release/sbt_app_hive/build.sbt +++ b/dev/audit-release/sbt_app_hive/build.sbt @@ -25,4 +25,5 @@ libraryDependencies += "org.apache.spark" %% "spark-hive" % System.getenv.get("S resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Maven Repository" at "http://repo1.maven.org/maven2/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_hive/src/main/resources/hive-site.xml b/dev/audit-release/sbt_app_hive/src/main/resources/hive-site.xml deleted file mode 100644 index 93b835813d535..0000000000000 --- a/dev/audit-release/sbt_app_hive/src/main/resources/hive-site.xml +++ /dev/null @@ -1,213 +0,0 @@ - - - - - - - - - - - - - - - - - - build.dir - ${user.dir}/build - - - - build.dir.hive - ${build.dir}/hive - - - - hadoop.tmp.dir - ${build.dir.hive}/test/hadoop-${user.name} - A base for other temporary directories. - - - - - - hive.exec.scratchdir - ${build.dir}/scratchdir - Scratch space for Hive jobs - - - - hive.exec.local.scratchdir - ${build.dir}/localscratchdir/ - Local scratch space for Hive jobs - - - - javax.jdo.option.ConnectionURL - - jdbc:derby:;databaseName=../build/test/junit_metastore_db;create=true - - - - javax.jdo.option.ConnectionDriverName - org.apache.derby.jdbc.EmbeddedDriver - - - - javax.jdo.option.ConnectionUserName - APP - - - - javax.jdo.option.ConnectionPassword - mine - - - - - hive.metastore.warehouse.dir - ${test.warehouse.dir} - - - - - hive.metastore.metadb.dir - ${build.dir}/test/data/metadb/ - - Required by metastore server or if the uris argument below is not supplied - - - - - test.log.dir - ${build.dir}/test/logs - - - - - test.src.dir - ${build.dir}/src/test - - - - - - - hive.jar.path - ${build.dir.hive}/ql/hive-exec-${version}.jar - - - - - hive.metastore.rawstore.impl - org.apache.hadoop.hive.metastore.ObjectStore - Name of the class that implements org.apache.hadoop.hive.metastore.rawstore interface. This class is used to store and retrieval of raw metadata objects such as table, database - - - - hive.querylog.location - ${build.dir}/tmp - Location of the structured hive logs - - - - - - hive.task.progress - false - Track progress of a task - - - - hive.support.concurrency - false - Whether hive supports concurrency or not. A zookeeper instance must be up and running for the default hive lock manager to support read-write locks. 
- - - - fs.pfile.impl - org.apache.hadoop.fs.ProxyLocalFileSystem - A proxy for local file system used for cross file system testing - - - - hive.exec.mode.local.auto - false - - Let hive determine whether to run in local mode automatically - Disabling this for tests so that minimr is not affected - - - - - hive.auto.convert.join - false - Whether Hive enable the optimization about converting common join into mapjoin based on the input file size - - - - hive.ignore.mapjoin.hint - false - Whether Hive ignores the mapjoin hint - - - - hive.input.format - org.apache.hadoop.hive.ql.io.CombineHiveInputFormat - The default input format, if it is not specified, the system assigns it. It is set to HiveInputFormat for hadoop versions 17, 18 and 19, whereas it is set to CombineHiveInputFormat for hadoop 20. The user can always overwrite it - if there is a bug in CombineHiveInputFormat, it can always be manually set to HiveInputFormat. - - - - hive.default.rcfile.serde - org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe - The default SerDe hive will use for the rcfile format - - - diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh new file mode 100755 index 0000000000000..7473c20d28e09 --- /dev/null +++ b/dev/change-version-to-2.10.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +find . -name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {} diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh new file mode 100755 index 0000000000000..3957a9f3ba258 --- /dev/null +++ b/dev/change-version-to-2.11.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +find . 
-name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.10|\1_2.11|g' {} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 281e8d4de6d71..e0aca467ac949 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -27,13 +27,20 @@ # Would be nice to add: # - Send output to stderr and have useful logging in stdout -GIT_USERNAME=${GIT_USERNAME:-pwendell} -GIT_PASSWORD=${GIT_PASSWORD:-XXX} +# Note: The following variables must be set before use! +ASF_USERNAME=${ASF_USERNAME:-pwendell} +ASF_PASSWORD=${ASF_PASSWORD:-XXX} GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.0.0} +RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} +NEXT_VERSION=${NEXT_VERSION:-1.2.1} RC_NAME=${RC_NAME:-rc2} -USER_NAME=${USER_NAME:-pwendell} + +M2_REPO=~/.m2/repository +SPARK_REPO=$M2_REPO/org/apache/spark +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_UPLOAD=$NEXUS_ROOT/deploy/maven2 +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads if [ -z "$JAVA_HOME" ]; then echo "Error: JAVA_HOME is not set, cannot proceed." @@ -46,31 +53,90 @@ set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME if [[ ! "$@" =~ --package-only ]]; then - echo "Creating and publishing release" + echo "Creating release commit and publishing to Apache repository" # Artifact publishing - git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH - cd spark + git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ + -b $GIT_BRANCH + pushd spark export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - mvn -Pyarn release:clean - - mvn -DskipTests \ - -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ - -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ - -Dmaven.javadoc.skip=true \ - -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - --batch-mode release:prepare - - mvn -DskipTests \ - -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ - -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Dmaven.javadoc.skip=true \ + # Create release commits and push them to github + # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build + # or before we coin the release commit. This helps avoid races where + # other people add commits to this branch while we are in the middle of building. + old=" ${RELEASE_VERSION}-SNAPSHOT<\/version>" + new=" ${RELEASE_VERSION}<\/version>" + find . -name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \ + -e "s/$old/$new/" {} + git commit -a -m "Preparing Spark release $GIT_TAG" + echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" + git tag $GIT_TAG + + old=" ${RELEASE_VERSION}<\/version>" + new=" ${NEXT_VERSION}-SNAPSHOT<\/version>" + find . 
-name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \ + -e "s/$old/$new/" {} + git commit -a -m "Preparing development version ${NEXT_VERSION}-SNAPSHOT" + git push origin $GIT_TAG + git push origin HEAD:$GIT_BRANCH + git checkout -f $GIT_TAG + + # Using Nexus API documented here: + # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API + echo "Creating Nexus staging repository" + repo_request="Apache Spark $GIT_TAG" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + + rm -rf $SPARK_REPO + + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - release:perform + clean install - cd .. + ./dev/change-version-to-2.11.sh + + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ + -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + clean install + + ./dev/change-version-to-2.10.sh + + pushd $SPARK_REPO + + # Remove any extra files generated during install + find . -type f |grep -v \.jar |grep -v \.pom | xargs rm + + echo "Creating hash and signature files" + for file in $(find . -type f) + do + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; + gpg --print-md MD5 $file > $file.md5; + gpg --print-md SHA1 $file > $file.sha1 + done + + echo "Uplading files to $NEXUS_UPLOAD" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$NEXUS_UPLOAD/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $GIT_TAG" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + + popd + popd rm -rf spark fi @@ -101,7 +167,13 @@ make_binary_release() { cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME - ./make-distribution.sh --name $NAME --tgz $FLAGS + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-version-to-2.11.sh + fi + + ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . 
rm -rf spark-$RELEASE_VERSION-bin-$NAME @@ -117,22 +189,24 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" & -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" & -make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" & + +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & +make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & +make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & +make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & +make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & -make_binary_release "mapr3" "-Pmapr3 -Phive" & -make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" & wait # Copy data echo "Copying release tarballs" rc_folder=spark-$RELEASE_VERSION-$RC_NAME -ssh $USER_NAME@people.apache.org \ - mkdir /home/$USER_NAME/public_html/$rc_folder +ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_folder scp spark-* \ - $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/ + $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ # Docs cd spark @@ -142,12 +216,12 @@ cd docs JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs -ssh $USER_NAME@people.apache.org \ - mkdir /home/$USER_NAME/public_html/$rc_docs_folder -rsync -r _site/* $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_docs_folder +ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder +rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder echo "Release $RELEASE_VERSION completed:" echo "Git tag:\t $GIT_TAG" echo "Release commit:\t $release_hash" -echo "Binary location:\t http://people.apache.org/~$USER_NAME/$rc_folder" -echo "Doc location:\t http://people.apache.org/~$USER_NAME/$rc_docs_folder" +echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" +echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py new file mode 100755 index 0000000000000..f4bf734081583 --- /dev/null +++ b/dev/create-release/generate-contributors.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# This script automates the process of creating release notes. + +import os +import re +import sys + +from releaseutils import * + +# You must set the following before use! +JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira") +START_COMMIT = os.environ.get("START_COMMIT", "37b100") +END_COMMIT = os.environ.get("END_COMMIT", "3693ae") + +try: + from jira.client import JIRA +except ImportError: + print "This tool requires the jira-python library" + print "Install using 'sudo pip install jira-python'" + sys.exit(-1) + +try: + import unidecode +except ImportError: + print "This tool requires the unidecode library to decode obscure github usernames" + print "Install using 'sudo pip install unidecode'" + sys.exit(-1) + +# If commit range is not specified, prompt the user to provide it +if not START_COMMIT or not END_COMMIT: + print "A commit range is required to proceed." + if not START_COMMIT: + START_COMMIT = raw_input("Please specify starting commit hash (inclusive): ") + if not END_COMMIT: + END_COMMIT = raw_input("Please specify ending commit hash (non-inclusive): ") + +# Verify provided arguments +start_commit_line = get_one_line(START_COMMIT) +end_commit_line = get_one_line(END_COMMIT) +num_commits = num_commits_in_range(START_COMMIT, END_COMMIT) +if not start_commit_line: sys.exit("Start commit %s not found!" % START_COMMIT) +if not end_commit_line: sys.exit("End commit %s not found!" % END_COMMIT) +if num_commits == 0: + sys.exit("There are no commits in the provided range [%s, %s)" % (START_COMMIT, END_COMMIT)) +print "\n==================================================================================" +print "JIRA server: %s" % JIRA_API_BASE +print "Start commit (inclusive): %s" % start_commit_line +print "End commit (non-inclusive): %s" % end_commit_line +print "Number of commits in this range: %s" % num_commits +print +response = raw_input("Is this correct? 
[Y/n] ") +if response.lower() != "y" and response: + sys.exit("Ok, exiting") +print "==================================================================================\n" + +# Find all commits within this range +print "Gathering commits within range [%s..%s)" % (START_COMMIT, END_COMMIT) +commits = get_one_line_commits(START_COMMIT, END_COMMIT) +if not commits: sys.exit("Error: No commits found within this range!") +commits = commits.split("\n") + +# Filter out special commits +releases = [] +reverts = [] +nojiras = [] +filtered_commits = [] +def is_release(commit): + return re.findall("\[release\]", commit.lower()) or\ + "maven-release-plugin" in commit or "CHANGES.txt" in commit +def has_no_jira(commit): + return not re.findall("SPARK-[0-9]+", commit.upper()) +def is_revert(commit): + return "revert" in commit.lower() +def is_docs(commit): + return re.findall("docs*", commit.lower()) or "programming guide" in commit.lower() +for c in commits: + if not c: continue + elif is_release(c): releases.append(c) + elif is_revert(c): reverts.append(c) + elif is_docs(c): filtered_commits.append(c) # docs may not have JIRA numbers + elif has_no_jira(c): nojiras.append(c) + else: filtered_commits.append(c) + +# Warn against ignored commits +def print_indented(_list): + for x in _list: print " %s" % x +if releases or reverts or nojiras: + print "\n==================================================================================" + if releases: print "Releases (%d)" % len(releases); print_indented(releases) + if reverts: print "Reverts (%d)" % len(reverts); print_indented(reverts) + if nojiras: print "No JIRA (%d)" % len(nojiras); print_indented(nojiras) + print "==================== Warning: the above commits will be ignored ==================\n" +response = raw_input("%d commits left to process. Ok to proceed? 
[y/N] " % len(filtered_commits))
+if response.lower() != "y":
+    sys.exit("Ok, exiting.")
+
+# Keep track of warnings to tell the user at the end
+warnings = []
+
+# Populate a map that groups issues and components by author
+# It takes the form: Author name -> { Contribution type -> Spark components }
+# For instance,
+# {
+#   'Andrew Or': {
+#     'bug fixes': ['windows', 'core', 'web ui'],
+#     'improvements': ['core']
+#   },
+#   'Tathagata Das' : {
+#     'bug fixes': ['streaming'],
+#     'new feature': ['streaming']
+#   }
+# }
+#
+author_info = {}
+jira_options = { "server": JIRA_API_BASE }
+jira = JIRA(jira_options)
+print "\n=========================== Compiling contributor list ==========================="
+for commit in filtered_commits:
+    commit_hash = re.findall("^[a-z0-9]+", commit)[0]
+    issues = re.findall("SPARK-[0-9]+", commit.upper())
+    author = get_author(commit_hash)
+    author = unidecode.unidecode(unicode(author, "UTF-8")) # guard against special characters
+    date = get_date(commit_hash)
+    # Parse components from the commit message, if any
+    commit_components = find_components(commit, commit_hash)
+    # Populate or merge an issue into author_info[author]
+    def populate(issue_type, components):
+        components = components or [CORE_COMPONENT] # assume core if no components provided
+        if author not in author_info:
+            author_info[author] = {}
+        if issue_type not in author_info[author]:
+            author_info[author][issue_type] = set()
+        for component in components:
+            author_info[author][issue_type].add(component)
+    # Find issues and components associated with this commit
+    for issue in issues:
+        jira_issue = jira.issue(issue)
+        jira_type = jira_issue.fields.issuetype.name
+        jira_type = translate_issue_type(jira_type, issue, warnings)
+        jira_components = [translate_component(c.name, commit_hash, warnings)\
+            for c in jira_issue.fields.components]
+        all_components = set(jira_components + commit_components)
+        populate(jira_type, all_components)
+    # For docs without an associated JIRA, manually add it ourselves
+    if is_docs(commit) and not issues:
+        populate("documentation", commit_components)
+    print " Processed commit %s authored by %s on %s" % (commit_hash, author, date)
+print "==================================================================================\n"
+
+# Write to contributors file ordered by author names
+# Each line takes the format "Author name - semi-colon delimited contributions"
+# e.g. Andrew Or - Bug fixes in Windows, Core, and Web UI; improvements in Core
+# e.g. Tathagata Das - Bug fixes and new features in Streaming
+contributors_file_name = "contributors.txt"
+contributors_file = open(contributors_file_name, "w")
+authors = author_info.keys()
+authors.sort()
+for author in authors:
+    contribution = ""
+    components = set()
+    issue_types = set()
+    for issue_type, comps in author_info[author].items():
+        components.update(comps)
+        issue_types.add(issue_type)
+    # If there is only one component, mention it only once
+    # e.g. Bug fixes, improvements in MLlib
+    if len(components) == 1:
+        contribution = "%s in %s" % (nice_join(issue_types), next(iter(components)))
+    # Otherwise, group contributions by issue types instead of modules
+    # e.g.
Bug fixes in MLlib, Core, and Streaming; documentation in YARN + else: + contributions = ["%s in %s" % (issue_type, nice_join(comps)) \ + for issue_type, comps in author_info[author].items()] + contribution = "; ".join(contributions) + # Do not use python's capitalize() on the whole string to preserve case + assert contribution + contribution = contribution[0].capitalize() + contribution[1:] + line = "%s - %s" % (author, contribution) + contributors_file.write(line + "\n") +contributors_file.close() +print "Contributors list is successfully written to %s!" % contributors_file_name + +# Log any warnings encountered in the process +if warnings: + print "\n============ Warnings encountered while creating the contributor list ============" + for w in warnings: print w + print "Please correct these in the final contributors list at %s." % contributors_file_name + print "==================================================================================\n" + diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py new file mode 100755 index 0000000000000..e56d7fa58fa2c --- /dev/null +++ b/dev/create-release/releaseutils.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file contains helper methods used in creating a release. 
+ +import re +from subprocess import Popen, PIPE + +# Utility functions run git commands (written with Git 1.8.5) +def run_cmd(cmd): return Popen(cmd, stdout=PIPE).communicate()[0] +def get_author(commit_hash): + return run_cmd(["git", "show", "--quiet", "--pretty=format:%an", commit_hash]) +def get_date(commit_hash): + return run_cmd(["git", "show", "--quiet", "--pretty=format:%cd", commit_hash]) +def get_one_line(commit_hash): + return run_cmd(["git", "show", "--quiet", "--pretty=format:\"%h %cd %s\"", commit_hash]) +def get_one_line_commits(start_hash, end_hash): + return run_cmd(["git", "log", "--oneline", "%s..%s" % (start_hash, end_hash)]) +def num_commits_in_range(start_hash, end_hash): + output = run_cmd(["git", "log", "--oneline", "%s..%s" % (start_hash, end_hash)]) + lines = [line for line in output.split("\n") if line] # filter out empty lines + return len(lines) + +# Maintain a mapping for translating issue types to contributions in the release notes +# This serves an additional function of warning the user against unknown issue types +# Note: This list is partially derived from this link: +# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/issuetypes +# Keep these in lower case +known_issue_types = { + "bug": "bug fixes", + "build": "build fixes", + "improvement": "improvements", + "new feature": "new features", + "documentation": "documentation" +} + +# Maintain a mapping for translating component names when creating the release notes +# This serves an additional function of warning the user against unknown components +# Note: This list is largely derived from this link: +# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/components +CORE_COMPONENT = "Core" +known_components = { + "block manager": CORE_COMPONENT, + "build": CORE_COMPONENT, + "deploy": CORE_COMPONENT, + "documentation": CORE_COMPONENT, + "ec2": "EC2", + "examples": CORE_COMPONENT, + "graphx": "GraphX", + "input/output": CORE_COMPONENT, + "java api": "Java API", + "mesos": "Mesos", + "ml": "MLlib", + "mllib": "MLlib", + "project infra": "Project Infra", + "pyspark": "PySpark", + "shuffle": "Shuffle", + "spark core": CORE_COMPONENT, + "spark shell": CORE_COMPONENT, + "sql": "SQL", + "streaming": "Streaming", + "web ui": "Web UI", + "windows": "Windows", + "yarn": "YARN" +} + +# Translate issue types using a format appropriate for writing contributions +# If an unknown issue type is encountered, warn the user +def translate_issue_type(issue_type, issue_id, warnings): + issue_type = issue_type.lower() + if issue_type in known_issue_types: + return known_issue_types[issue_type] + else: + warnings.append("Unknown issue type \"%s\" (see %s)" % (issue_type, issue_id)) + return issue_type + +# Translate component names using a format appropriate for writing contributions +# If an unknown component is encountered, warn the user +def translate_component(component, commit_hash, warnings): + component = component.lower() + if component in known_components: + return known_components[component] + else: + warnings.append("Unknown component \"%s\" (see %s)" % (component, commit_hash)) + return component + +# Parse components in the commit message +# The returned components are already filtered and translated +def find_components(commit, commit_hash): + components = re.findall("\[\w*\]", commit.lower()) + components = [translate_component(c, commit_hash)\ + for c in components if c in known_components] + return components + +# Join a list of strings in a human-readable manner +# e.g. 
["Juice"] -> "Juice" +# e.g. ["Juice", "baby"] -> "Juice and baby" +# e.g. ["Juice", "baby", "moon"] -> "Juice, baby, and moon" +def nice_join(str_list): + str_list = list(str_list) # sometimes it's a set + if not str_list: + return "" + elif len(str_list) == 1: + return next(iter(str_list)) + elif len(str_list) == 2: + return " and ".join(str_list) + else: + return ", ".join(str_list[:-1]) + ", and " + str_list[-1] + diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 02ac20984add9..dfa924d2aa0ba 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -214,15 +214,10 @@ def fix_version_from_branch(branch, versions): return filter(lambda x: x.name.startswith(branch_ver), versions)[-1] -def resolve_jira(title, merge_branches, comment): +def resolve_jira_issue(merge_branches, comment, default_jira_id=""): asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) - default_jira_id = "" - search = re.findall("SPARK-[0-9]{4,5}", title) - if len(search) > 0: - default_jira_id = search[0] - jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": jira_id = default_jira_id @@ -280,6 +275,15 @@ def get_version_json(version_str): print "Succesfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) +def resolve_jira_issues(title, merge_branches, comment): + jira_ids = re.findall("SPARK-[0-9]{4,5}", title) + + if len(jira_ids) == 0: + resolve_jira_issue(merge_branches, comment) + for jira_id in jira_ids: + resolve_jira_issue(merge_branches, comment, jira_id) + + branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) # Assumes branch names can be sorted lexicographically @@ -338,7 +342,7 @@ def get_version_json(version_str): if JIRA_USERNAME and JIRA_PASSWORD: continue_maybe("Would you like to update an associated JIRA?") jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) - resolve_jira(title, merged_refs, jira_comment) + resolve_jira_issues(title, merged_refs, jira_comment) else: print "JIRA_USERNAME and JIRA_PASSWORD not set" print "Exiting without trying to close the associated JIRA." diff --git a/dev/run-tests b/dev/run-tests index de607e4344453..328a73bd8b26d 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -139,9 +139,6 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_BUILD { - # We always build with Hive because the PySpark Spark SQL tests need it. - BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-0.12.0" - # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc @@ -151,15 +148,17 @@ CURRENT_BLOCK=$BLOCK_BUILD # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? 
# First build with 0.12 to ensure patches do not break the hive 12 build + HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0" echo "[info] Compile with hive 0.12" echo -e "q\n" \ - | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean hive/compile hive-thriftserver/compile \ + | sbt/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" # Then build with default version(0.13.1) because tests are based on this version - echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS -Phive" + echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\ + " -Phive -Phive-thriftserver" echo -e "q\n" \ - | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive package assembly/assembly \ + | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } @@ -174,7 +173,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. # This must be a single argument, as it is. if [ -n "$_RUN_SQL_TESTS" ]; then - SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" fi if [ -n "$_SQL_TESTS_ONLY" ]; then diff --git a/dev/scalastyle b/dev/scalastyle index ed1b6b730af6e..c3c6012e74ffa 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt diff --git a/docs/README.md b/docs/README.md index d2d58e435d4c4..119484038083f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -43,7 +43,7 @@ You can modify the default Jekyll build as follows: ## Pygments We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, -so you will also need to install that (it requires Python) by running `sudo easy_install Pygments`. +so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile phase, use the following sytax: @@ -53,6 +53,11 @@ phase, use the following sytax: // supported languages too. {% endhighlight %} +## Sphinx + +We use Sphinx to generate Python API docs, so you will need to install it by running +`sudo pip install sphinx`. + ## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. diff --git a/docs/_config.yml b/docs/_config.yml index cdea02fcffbc5..a96a76dd9ab5e 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -13,8 +13,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. 
-SPARK_VERSION: 1.2.0-SNAPSHOT
-SPARK_VERSION_SHORT: 1.2.0
+SPARK_VERSION: 1.3.0-SNAPSHOT
+SPARK_VERSION_SHORT: 1.3.0
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.4"
MESOS_VERSION: 0.18.1
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 238ddae15545e..6cca2da8e86d2 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -92,8 +92,11 @@ mvn -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package
# Apache Hadoop 2.3.X
mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package
-# Apache Hadoop 2.4.X
-mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package
+# Apache Hadoop 2.4.X or 2.5.X
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package
+
+Versions of Hadoop after 2.5.X may or may not work with the -Phadoop-2.4 profile (they were
+released after this version of Spark).
# Different versions of HDFS and YARN.
mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package
@@ -101,25 +104,35 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski
# Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
-add the `-Phive` profile to your existing build options. By default Spark
-will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using
-the `-Phive-0.12.0` profile.
+add the `-Phive` and `-Phive-thriftserver` profiles to your existing build options.
+By default Spark will build with Hive 0.13.1 bindings. You can also build for
+Hive 0.12.0 using the `-Phive-0.12.0` profile.
{% highlight bash %}
# Apache Hadoop 2.4.X with Hive 13 support
-mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package
# Apache Hadoop 2.4.X with Hive 12 support
-mvn -Pyarn -Phive-0.12.0 -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-0.12.0 -Phive-thriftserver -DskipTests clean package
{% endhighlight %}
+# Building for Scala 2.11
+To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property:
+
+    dev/change-version-to-2.11.sh
+    mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
+
+Scala 2.11 support in Spark is experimental and does not support a few features.
+Specifically, Spark's external Kafka library and JDBC component are not yet
+supported in Scala 2.11 builds.
+
# Spark Tests in Maven
Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin).
Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time.
The following is an example of a correct (build, test) sequence:
-    mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package
-    mvn -Pyarn -Phadoop-2.3 -Phive test
+    mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package
+    mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test
The ScalaTest plugin also supports running only a specific test suite as follows:
@@ -182,16 +195,16 @@ can be set to control the SBT build. For example:
Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time.
The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" To run test suites of a specific sub project as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test # Speeding up Compilation with Zinc diff --git a/docs/configuration.md b/docs/configuration.md index f0b396e21f198..0b77f5ab645c9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -52,7 +52,7 @@ Then, you can supply configuration values at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} -The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit) +The Spark shell and [`spark-submit`](submitting-applications.html) tool support two ways to load configurations dynamically. The first are command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. @@ -224,6 +224,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in Executors. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. + (Currently, this setting does not work for YARN, see SPARK-2996 for more details). diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 530798f2b8022..66bf5f1a855ed 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -12,16 +12,14 @@ on the [Amazon Web Services site](http://aws.amazon.com/). `spark-ec2` is designed to manage multiple named clusters. You can launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster -launches a set of instances, which are tagged with the cluster name, -and placed into EC2 security groups. If you don't specify a security -group, the `spark-ec2` script will create security groups based on the -cluster name you request. For example, a cluster named +shutdown an existing cluster, or log into a cluster. Each cluster is +identified by placing its machines into EC2 security groups whose names +are derived from the name of the cluster. For example, a cluster named `test` will contain a master node in a security group called `test-master`, and a number of slave nodes in a security group called -`test-slaves`. You can also specify a security group prefix to be used -in place of the cluster name. Machines in a cluster can be identified -by looking for the "Name" tag of the instance in the Amazon EC2 Console. +`test-slaves`. The `spark-ec2` script will create these security groups +for you based on the cluster name you request. You can also use them to +identify machines belonging to each cluster in the Amazon EC2 Console. 
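A minimal sketch of how the derived security group names described above can be used to find a cluster's machines, assuming the AWS CLI is installed and configured (the commands below are illustrative only and are not part of spark-ec2 or this patch):

{% highlight bash %}
CLUSTER_NAME=test
MASTER_GROUP="${CLUSTER_NAME}-master"   # security group holding the master node
SLAVES_GROUP="${CLUSTER_NAME}-slaves"   # security group holding the slave nodes

# List the EC2 instances that belong to this cluster's slave group
aws ec2 describe-instances \
  --filters "Name=instance.group-name,Values=${SLAVES_GROUP}"
{% endhighlight %}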
# Before You Start diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index fdb9f98e214e5..e298c51f8a5b7 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -6,6 +6,47 @@ title: GraphX Programming Guide * This will become a table of contents (this text will be scraped). {:toc} + + +[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge +[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet +[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph +[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps +[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] +[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED] +[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED] +[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED] +[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED] +[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED] +[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] +[Graph.aggregateMessages]: api/scala/index.html#org.apache.spark.graphx.Graph@aggregateMessages[A]((EdgeContext[VD,ED,A])⇒Unit,(A,A)⇒A,TripletFields)(ClassTag[A]):VertexRDD[A] +[EdgeContext]: api/scala/index.html#org.apache.spark.graphx.EdgeContext +[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] +[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] +[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] +[RDD Persistence]: programming-guide.html#rdd-persistence +[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED] +[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED] +[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$ +[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int] +[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] +[Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] +[Graph.fromEdges]: 
api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] +[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy +[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED] +[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ +[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ +[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ +[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED] +[EdgeContext.sendToSrc]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToSrc(msg:A):Unit +[EdgeContext.sendToDst]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToDst(msg:A):Unit +[TripletFields]: api/java/org/apache/spark/graphx/TripletFields.html +[TripletFields.All]: api/java/org/apache/spark/graphx/TripletFields.html#All +[TripletFields.None]: api/java/org/apache/spark/graphx/TripletFields.html#None +[TripletFields.Src]: api/java/org/apache/spark/graphx/TripletFields.html#Src +[TripletFields.Dst]: api/java/org/apache/spark/graphx/TripletFields.html#Dst +

- [removed figure: Data-Parallel vs. Graph-Parallel]

    +1. To improve performance we have introduced a new version of +[`mapReduceTriplets`][Graph.mapReduceTriplets] called +[`aggregateMessages`][Graph.aggregateMessages] which takes the messages previously returned from +[`mapReduceTriplets`][Graph.mapReduceTriplets] through a callback ([`EdgeContext`][EdgeContext]) +rather than by return value. +We are deprecating [`mapReduceTriplets`][Graph.mapReduceTriplets] and encourage users to consult +the [transition guide](#mrTripletsTransition). -However, the same restrictions that enable these substantial performance gains also make it -difficult to express many of the important stages in a typical graph-analytics pipeline: -constructing the graph, modifying its structure, or expressing computation that spans multiple -graphs. Furthermore, how we look at data depends on our objectives and the same raw data may have -many different table and graph views. - -

- [removed figure: Tables and Graphs]

    - -As a consequence, it is often necessary to be able to move between table and graph views of the same -physical data and to leverage the properties of each view to easily and efficiently express -computation. However, existing graph analytics pipelines must compose graph-parallel and data- -parallel systems, leading to extensive data movement and duplication and a complicated programming -model. - -

- [removed figure: Graph Analytics Pipeline]

    - -The goal of the GraphX project is to unify graph-parallel and data-parallel computation in one -system with a single composable API. The GraphX API enables users to view data both as a graph and -as collections (i.e., RDDs) without data movement or duplication. By incorporating recent advances -in graph-parallel systems, GraphX is able to optimize the execution of graph operations. - -## GraphX Replaces the Spark Bagel API - -Prior to the release of GraphX, graph computation in Spark was expressed using Bagel, an -implementation of Pregel. GraphX improves upon Bagel by exposing a richer property graph API, a -more streamlined version of the Pregel abstraction, and system optimizations to improve performance -and reduce memory overhead. While we plan to eventually deprecate Bagel, we will continue to -support the [Bagel API](api/scala/index.html#org.apache.spark.bagel.package) and -[Bagel programming guide](bagel-programming-guide.html). However, we encourage Bagel users to -explore the new GraphX API and comment on issues that may complicate the transition from Bagel. - -## Migrating from Spark 0.9.1 - -GraphX in Spark {{site.SPARK_VERSION}} contains one user-facing interface change from Spark 0.9.1. [`EdgeRDD`][EdgeRDD] may now store adjacent vertex attributes to construct the triplets, so it has gained a type parameter. The edges of a graph of type `Graph[VD, ED]` are of type `EdgeRDD[ED, VD]` rather than `EdgeRDD[ED]`. - -[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +2. In Spark 1.0 and 1.1, the type signature of [`EdgeRDD`][EdgeRDD] switched from +`EdgeRDD[ED]` to `EdgeRDD[ED, VD]` to enable some caching optimizations. We have since discovered +a more elegant solution and have restored the type signature to the more natural `EdgeRDD[ED]` type. # Getting Started @@ -108,9 +96,10 @@ import org.apache.spark.rdd.RDD If you are not using the Spark shell you will also need a `SparkContext`. To learn more about getting started with Spark refer to the [Spark Quick Start Guide](quick-start.html). -# The Property Graph +# The Property Graph + The [property graph](api/scala/index.html#org.apache.spark.graphx.Graph) is a directed multigraph with user defined objects attached to each vertex and edge. A directed multigraph is a directed graph with potentially multiple parallel edges sharing the same source and destination vertex. The @@ -123,7 +112,7 @@ identifiers. The property graph is parameterized over the vertex (`VD`) and edge (`ED`) types. These are the types of the objects associated with each vertex and edge respectively. -> GraphX optimizes the representation of vertex and edge types when they are plain old data-types +> GraphX optimizes the representation of vertex and edge types when they are primitive data types > (e.g., int, double, etc...) reducing the in memory footprint by storing them in specialized > arrays. @@ -142,8 +131,8 @@ var graph: Graph[VertexProperty, String] = null Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Changes to the values or structure of the graph are accomplished by producing a new graph with the desired changes. Note that substantial parts of the original graph (i.e., unaffected structure, attributes, and indicies) -are reused in the new graph reducing the cost of this inherently functional data-structure. The -graph is partitioned across the executors using a range of vertex-partitioning heuristics. As with +are reused in the new graph reducing the cost of this inherently functional data structure. 
The +graph is partitioned across the executors using a range of vertex partitioning heuristics. As with RDDs, each partition of the graph can be recreated on a different machine in the event of a failure. Logically the property graph corresponds to a pair of typed collections (RDDs) encoding the @@ -153,12 +142,12 @@ the vertices and edges of the graph: {% highlight scala %} class Graph[VD, ED] { val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED, VD] + val edges: EdgeRDD[ED] } {% endhighlight %} -The classes `VertexRDD[VD]` and `EdgeRDD[ED, VD]` extend and are optimized versions of `RDD[(VertexID, -VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED, VD]` provide additional +The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID, +VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional functionality built around graph computation and leverage internal optimizations. We discuss the `VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form: @@ -211,7 +200,6 @@ In the above example we make use of the [`Edge`][Edge] case class. Edges have a `dstId` corresponding to the source and destination vertex identifiers. In addition, the `Edge` class has an `attr` member which stores the edge property. -[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge We can deconstruct a graph into the respective vertex and edge views by using the `graph.vertices` and `graph.edges` members respectively. @@ -237,7 +225,6 @@ The triplet view logically joins the vertex and edge properties yielding an `RDD[EdgeTriplet[VD, ED]]` containing instances of the [`EdgeTriplet`][EdgeTriplet] class. This *join* can be expressed in the following SQL expression: -[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet {% highlight sql %} SELECT src.id, dst.id, src.attr, e.attr, dst.attr @@ -278,9 +265,6 @@ core operators are defined in [`GraphOps`][GraphOps]. However, thanks to Scala operators in `GraphOps` are automatically available as members of `Graph`. 
For example, we can compute the in-degree of each vertex (defined in `GraphOps`) by the following: -[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph -[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps - {% highlight scala %} val graph: Graph[(String, String), String] // Use the implicit GraphOps.inDegrees operator @@ -310,7 +294,7 @@ class Graph[VD, ED] { val degrees: VertexRDD[Int] // Views of the graph as collections ============================================================= val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED, VD] + val edges: EdgeRDD[ED] val triplets: RDD[EdgeTriplet[VD, ED]] // Functions for caching graphs ================================================================== def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED] @@ -341,10 +325,10 @@ class Graph[VD, ED] { // Aggregate information about adjacent triplets ================================================= def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] - def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)], - reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) + def aggregateMessages[Msg: ClassTag]( + sendMsg: EdgeContext[VD, ED, Msg] => Unit, + mergeMsg: (Msg, Msg) => Msg, + tripletFields: TripletFields = TripletFields.All) : VertexRDD[A] // Iterative graph-parallel computation ========================================================== def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)( @@ -363,8 +347,7 @@ class Graph[VD, ED] { ## Property Operators -In direct analogy to the RDD `map` operator, the property -graph contains the following: +Like the RDD `map` operator, the property graph contains the following: {% highlight scala %} class Graph[VD, ED] { @@ -377,7 +360,7 @@ class Graph[VD, ED] { Each of these operators yields a new graph with the vertex or edge properties modified by the user defined `map` function. -> Note that in all cases the graph structure is unaffected. This is a key feature of these operators +> Note that in each case the graph structure is unaffected. This is a key feature of these operators > which allows the resulting graph to reuse the structural indices of the original graph. The > following snippets are logically equivalent, but the first one does not preserve the structural > indices and would not benefit from the GraphX system optimizations: @@ -390,14 +373,13 @@ val newGraph = Graph(newVertices, graph.edges) val newGraph = graph.mapVertices((id, attr) => mapUdf(id, attr)) {% endhighlight %} -[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] These operators are often used to initialize the graph for a particular computation or project away -unnecessary properties. For example, given a graph with the out-degrees as the vertex properties +unnecessary properties. 
For example, given a graph with the out degrees as the vertex properties (we describe how to construct such a graph later), we initialize it for PageRank: {% highlight scala %} -// Given a graph where the vertex property is the out-degree +// Given a graph where the vertex property is the out degree val inputGraph: Graph[Int, String] = graph.outerJoinVertices(graph.outDegrees)((vid, _, degOpt) => degOpt.getOrElse(0)) // Construct a graph where each edge contains the weight @@ -406,9 +388,10 @@ val outputGraph: Graph[Double, Double] = inputGraph.mapTriplets(triplet => 1.0 / triplet.srcAttr).mapVertices((id, _) => 1.0) {% endhighlight %} -## Structural Operators +## Structural Operators + Currently GraphX supports only a simple set of commonly used structural operators and we expect to add more in the future. The following is a list of the basic structural operators. @@ -425,9 +408,8 @@ class Graph[VD, ED] { The [`reverse`][Graph.reverse] operator returns a new graph with all the edge directions reversed. This can be useful when, for example, trying to compute the inverse PageRank. Because the reverse operation does not modify vertex or edge properties or change the number of edges, it can be -implemented efficiently without data-movement or duplication. +implemented efficiently without data movement or duplication. -[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED] The [`subgraph`][Graph.subgraph] operator takes vertex and edge predicates and returns the graph containing only the vertices that satisfy the vertex predicate (evaluate to true) and edges that @@ -435,7 +417,6 @@ satisfy the edge predicate *and connect vertices that satisfy the vertex predica operator can be used in number of situations to restrict the graph to the vertices and edges of interest or eliminate broken links. For example in the following code we remove broken links: -[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED] {% highlight scala %} // Create an RDD for the vertices @@ -469,13 +450,12 @@ validGraph.triplets.map( > Note in the above example only the vertex predicate is provided. The `subgraph` operator defaults > to `true` if the vertex or edge predicates are not provided. -The [`mask`][Graph.mask] operator also constructs a subgraph by returning a graph that contains the +The [`mask`][Graph.mask] operator constructs a subgraph by returning a graph that contains the vertices and edges that are also found in the input graph. This can be used in conjunction with the `subgraph` operator to restrict a graph based on the properties in another related graph. For example, we might run connected components using the graph with missing vertices and then restrict the answer to the valid subgraph. -[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED] {% highlight scala %} // Run Connected Components @@ -490,10 +470,9 @@ The [`groupEdges`][Graph.groupEdges] operator merges parallel edges (i.e., dupli pairs of vertices) in the multigraph. In many numerical applications, parallel edges can be *added* (their weights combined) into a single edge thereby reducing the size of the graph. -[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED] + ## Join Operators - In many cases it is necessary to join data from external collections (RDDs) with graphs. 
For example, we might have extra user properties that we want to merge with an existing graph or we @@ -514,10 +493,8 @@ returns a new graph with the vertex properties obtained by applying the user def to the result of the joined vertices. Vertices without a matching value in the RDD retain their original value. -[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED] - -> Note that if the RDD contains more than one value for a given vertex only one will be used. It -> is therefore recommended that the input RDD be first made unique using the following which will +> Note that if the RDD contains more than one value for a given vertex only one will be used. It +> is therefore recommended that the input RDD be made unique using the following which will > also *pre-index* the resulting values to substantially accelerate the subsequent join. > {% highlight scala %} val nonUniqueCosts: RDD[(VertexID, Double)] @@ -533,8 +510,6 @@ property type. Because not all vertices may have a matching value in the input function takes an `Option` type. For example, we can setup a graph for PageRank by initializing vertex properties with their `outDegree`. -[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] - {% highlight scala %} val outDegrees: VertexRDD[Int] = graph.outDegrees @@ -555,65 +530,76 @@ val joinedGraph = graph.joinVertices(uniqueCosts, (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost) {% endhighlight %} +> + + ## Neighborhood Aggregation -A key part of graph computation is aggregating information about the neighborhood of each vertex. -For example we might want to know the number of followers each user has or the average age of the +A key step in many graph analytics tasks is aggregating information about the neighborhood of each +vertex. +For example, we might want to know the number of followers each user has or the average age of the followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and connected components) repeatedly aggregate properties of neighboring vertices (e.g., current PageRank Value, shortest path to the source, and smallest reachable vertex id). -### Map Reduce Triplets (mapReduceTriplets) - +> To improve performance, the primary aggregation operator changed from +`graph.mapReduceTriplets` to the new `graph.aggregateMessages`. While the changes in the API are +relatively small, we provide a transition guide below. -[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] + -The core (heavily optimized) aggregation primitive in GraphX is the -[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: +### Aggregate Messages (aggregateMessages) + +The core aggregation operation in GraphX is [`aggregateMessages`][Graph.aggregateMessages]. +This operator applies a user defined `sendMsg` function to each edge triplet in the graph +and then uses the `mergeMsg` function to aggregate those messages at their destination vertex.
{% highlight scala %} class Graph[VD, ED] { - def mapReduceTriplets[A]( - map: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduce: (A, A) => A) - : VertexRDD[A] + def aggregateMessages[Msg: ClassTag]( + sendMsg: EdgeContext[VD, ED, Msg] => Unit, + mergeMsg: (Msg, Msg) => Msg, + tripletFields: TripletFields = TripletFields.All) + : VertexRDD[Msg] } {% endhighlight %} -The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which -is applied to each triplet and can yield *messages* destined to either (none or both) vertices in -the triplet. To facilitate optimized pre-aggregation, we currently only support messages destined -to the source or destination vertex of the triplet. The user defined `reduce` function combines the -messages destined to each vertex. The `mapReduceTriplets` operator returns a `VertexRDD[A]` -containing the aggregate message (of type `A`) destined to each vertex. Vertices that do not +The user defined `sendMsg` function takes an [`EdgeContext`][EdgeContext], which exposes the +source and destination attributes along with the edge attribute and functions +([`sendToSrc`][EdgeContext.sendToSrc], and [`sendToDst`][EdgeContext.sendToDst]) to send +messages to the source and destination attributes. Think of `sendMsg` as the map +function in map-reduce. +The user defined `mergeMsg` function takes two messages destined to the same vertex and +yields a single message. Think of `mergeMsg` as the reduce function in map-reduce. +The [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]` +containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not receive a message are not included in the returned `VertexRDD`. -
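For instance, a minimal sketch of this pattern (assuming a follower graph `graph: Graph[Double, Int]` such as the one constructed in the example below) counts the followers of each user by sending the message `1` to the destination of every edge and summing the messages at each vertex:

{% highlight scala %}
// Count followers: send 1 along every edge and sum the messages at the destination.
val numFollowers: VertexRDD[Int] = graph.aggregateMessages[Int](
  ctx => ctx.sendToDst(1), // sendMsg: one message per incoming edge
  _ + _)                   // mergeMsg: add the counts
{% endhighlight %}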
    - -

    Note that mapReduceTriplets takes an additional optional activeSet -(not shown above see API docs for details) which restricts the map phase to edges adjacent to the -vertices in the provided VertexRDD:

    - -{% highlight scala %} - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None -{% endhighlight %} - -

    The EdgeDirection specifies which edges adjacent to the vertex set are included in the map -phase. If the direction is In, then the user defined map function will -only be run only on edges with the destination vertex in the active set. If the direction is -Out, then the map function will only be run only on edges originating from -vertices in the active set. If the direction is Either, then the map -function will be run only on edges with either vertex in the active set. If the direction is -Both, then the map function will be run only on edges with both vertices -in the active set. The active set must be derived from the set of vertices in the graph. -Restricting computation to triplets adjacent to a subset of the vertices is often necessary in -incremental iterative computation and is a key part of the GraphX implementation of Pregel.

    - -
- -In the following example we use the `mapReduceTriplets` operator to compute the average age of the -more senior followers of each user. + + +In addition, [`aggregateMessages`][Graph.aggregateMessages] takes an optional +`tripletFields` which indicates what data is accessed in the [`EdgeContext`][EdgeContext] +(i.e., the source vertex attribute but not the destination vertex attribute). +The possible options for the `tripletFields` are defined in [`TripletFields`][TripletFields] and +the default value is [`TripletFields.All`][TripletFields.All] which indicates that the user +defined `sendMsg` function may access any of the fields in the [`EdgeContext`][EdgeContext]. +The `tripletFields` argument can be used to notify GraphX that only part of the +[`EdgeContext`][EdgeContext] will be needed allowing GraphX to select an optimized join strategy. +For example, if we are computing the average age of the followers of each user we would only require +the source field and so we would use [`TripletFields.Src`][TripletFields.Src] to indicate that we +only require the source field. + +> In earlier versions of GraphX we used bytecode inspection to infer +[`TripletFields`][TripletFields]; however, we found bytecode inspection to be +slightly unreliable and instead opted for more explicit user control. + +In the following example we use the [`aggregateMessages`][Graph.aggregateMessages] operator to +compute the average age of the more senior followers of each user. {% highlight scala %} // Import random graph generation library @@ -622,14 +608,11 @@ import org.apache.spark.graphx.util.GraphGenerators val graph: Graph[Double, Int] = GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble ) // Compute the number of older followers and their total age -val olderFollowers: VertexRDD[(Int, Double)] = graph.mapReduceTriplets[(Int, Double)]( +val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)]( triplet => { // Map Function if (triplet.srcAttr > triplet.dstAttr) { // Send message to destination vertex containing counter and age - Iterator((triplet.dstId, (1, triplet.srcAttr))) - } else { - // Don't send a message for this triplet - Iterator.empty + triplet.sendToDst((1, triplet.srcAttr)) } }, // Add counter and age @@ -642,10 +625,57 @@ val avgAgeOfOlderFollowers: VertexRDD[Double] = avgAgeOfOlderFollowers.collect.foreach(println(_)) {% endhighlight %} -> Note that the `mapReduceTriplets` operation performs optimally when the messages (and the sums of -> messages) are constant sized (e.g., floats and addition instead of lists and concatenation). More -> precisely, the result of `mapReduceTriplets` should ideally be sub-linear in the degree of each -> vertex. +> The `aggregateMessages` operation performs optimally when the messages (and the sums of > messages) are constant sized (e.g., floats and addition instead of lists and concatenation).
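To make the `tripletFields` hint concrete, here is a small sketch (assuming the same follower graph `graph: Graph[Double, Int]` as above, where the vertex attribute is an age) that finds the oldest follower of each user; because only the source attribute is read, passing [`TripletFields.Src`][TripletFields.Src] lets GraphX avoid shipping destination vertex attributes to the edge partitions:

{% highlight scala %}
// Oldest follower of each user: only the source vertex attribute is accessed.
val oldestFollower: VertexRDD[Double] = graph.aggregateMessages[Double](
  ctx => ctx.sendToDst(ctx.srcAttr), // send the follower's age to the followee
  (a, b) => math.max(a, b),          // keep the maximum age
  TripletFields.Src)                 // hint: destination attributes are not needed
{% endhighlight %}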
+ + + +### Map Reduce Triplets Transition Guide (Legacy) + +In earlier versions of GraphX, neighborhood aggregation was accomplished using the +[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: + +{% highlight scala %} +class Graph[VD, ED] { + def mapReduceTriplets[Msg]( + map: EdgeTriplet[VD, ED] => Iterator[(VertexId, Msg)], + reduce: (Msg, Msg) => Msg) + : VertexRDD[Msg] +} +{% endhighlight %} + +The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which +is applied to each triplet and can yield *messages* which are aggregated using the user defined +`reduce` function. +However, we found the use of the returned iterator to be expensive and it inhibited our ability to +apply additional optimizations (e.g., local vertex renumbering). +In [`aggregateMessages`][Graph.aggregateMessages] we introduced the EdgeContext which exposes the +triplet fields and also functions to explicitly send messages to the source and destination vertex. +Furthermore we removed bytecode inspection and instead require the user to indicate what fields +in the triplet are actually required. + +The following code block using `mapReduceTriplets`: + +{% highlight scala %} +val graph: Graph[Int, Float] = ... +def msgFun(triplet: EdgeTriplet[Int, Float]): Iterator[(VertexId, String)] = { + Iterator((triplet.dstId, "Hi")) +} +def reduceFun(a: String, b: String): String = a + b +val result = graph.mapReduceTriplets[String](msgFun, reduceFun) +{% endhighlight %} + +can be rewritten using `aggregateMessages` as: + +{% highlight scala %} +val graph: Graph[Int, Float] = ... +def msgFun(triplet: EdgeContext[Int, Float, String]) { + triplet.sendToDst("Hi") +} +def reduceFun(a: String, b: String): String = a + b +val result = graph.aggregateMessages[String](msgFun, reduceFun) +{% endhighlight %} + ### Computing Degree Information @@ -673,10 +703,6 @@ attributes at each vertex. This can be easily accomplished using the [`collectNeighborIds`][GraphOps.collectNeighborIds] and the [`collectNeighbors`][GraphOps.collectNeighbors] operators. -[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] -[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] - - {% highlight scala %} class GraphOps[VD, ED] { def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] @@ -684,36 +710,34 @@ class GraphOps[VD, ED] { } {% endhighlight %} -> Note that these operators can be quite costly as they duplicate information and require +> These operators can be quite costly as they duplicate information and require > substantial communication. If possible try expressing the same computation using the -> `mapReduceTriplets` operator directly. +> [`aggregateMessages`][Graph.aggregateMessages] operator directly. ## Caching and Uncaching In Spark, RDDs are not persisted in memory by default. To avoid recomputation, they must be explicitly cached when using them multiple times (see the [Spark Programming Guide][RDD Persistence]). Graphs in GraphX behave the same way. **When using a graph multiple times, make sure to call [`Graph.cache()`][Graph.cache] on it first.** -[RDD Persistence]: programming-guide.html#rdd-persistence -[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED] In iterative computations, *uncaching* may also be necessary for best performance.
By default, cached RDDs and graphs will remain in memory until memory pressure forces them to be evicted in LRU order. For iterative computation, intermediate results from previous iterations will fill up the cache. Though they will eventually be evicted, the unnecessary data stored in memory will slow down garbage collection. It would be more efficient to uncache intermediate results as soon as they are no longer necessary. This involves materializing (caching and forcing) a graph or RDD every iteration, uncaching all other datasets, and only using the materialized dataset in future iterations. However, because graphs are composed of multiple RDDs, it can be difficult to unpersist them correctly. **For iterative computation we recommend using the Pregel API, which correctly unpersists intermediate results.** -# Pregel API -Graphs are inherently recursive data-structures as properties of vertices depend on properties of +# Pregel API + +Graphs are inherently recursive data structures as properties of vertices depend on properties of their neighbors which in turn depend on properties of *their* neighbors. As a consequence many important graph algorithms iteratively recompute the properties of each vertex until a fixed-point condition is reached. A range of graph-parallel abstractions have been proposed -to express these iterative algorithms. GraphX exposes a Pregel-like operator which is a fusion of -the widely used Pregel and GraphLab abstractions. +to express these iterative algorithms. GraphX exposes a variant of the Pregel API. -At a high-level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction -*constrained to the topology of the graph*. The Pregel operator executes in a series of super-steps -in which vertices receive the *sum* of their inbound messages from the previous super- step, compute +At a high level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction +*constrained to the topology of the graph*. The Pregel operator executes in a series of super steps +in which vertices receive the *sum* of their inbound messages from the previous super step, compute a new value for the vertex property, and then send messages to neighboring vertices in the next -super-step. Unlike Pregel and instead more like GraphLab messages are computed in parallel as a +super step. Unlike Pregel, messages are computed in parallel as a function of the edge triplet and the message computation has access to both the source and -destination vertex attributes. Vertices that do not receive a message are skipped within a super- +destination vertex attributes. Vertices that do not receive a message are skipped within a super step. The Pregel operators terminates iteration and returns the final graph when there are no messages remaining. @@ -724,8 +748,6 @@ messages remaining. 
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* of its implementation (note calls to graph.cache have been removed): -[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED] - {% highlight scala %} class GraphOps[VD, ED] { def pregel[A] @@ -795,9 +817,10 @@ val sssp = initialGraph.pregel(Double.PositiveInfinity)( println(sssp.vertices.collect.mkString("\n")) {% endhighlight %} -# Graph Builders +# Graph Builders + GraphX provides several ways of building a graph from a collection of vertices and edges in an RDD or on disk. None of the graph builders repartitions the graph's edges by default; instead, edges are left in their default partitions (such as their original blocks in HDFS). [`Graph.groupEdges`][Graph.groupEdges] requires the graph to be repartitioned because it assumes identical edges will be colocated on the same partition, so you must call [`Graph.partitionBy`][Graph.partitionBy] before calling `groupEdges`. {% highlight scala %} @@ -848,18 +871,12 @@ object Graph { [`Graph.fromEdgeTuples`][Graph.fromEdgeTuples] allows creating a graph from only an RDD of edge tuples, assigning the edges the value 1, and automatically creating any vertices mentioned by edges and assigning them the default value. It also supports deduplicating the edges; to deduplicate, pass `Some` of a [`PartitionStrategy`][PartitionStrategy] as the `uniqueEdges` parameter (for example, `uniqueEdges = Some(PartitionStrategy.RandomVertexCut)`). A partition strategy is necessary to colocate identical edges on the same partition so they can be deduplicated. -[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$ - -[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int] -[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] -[Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] -[Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] + # Vertex and Edge RDDs - GraphX exposes `RDD` views of the vertices and edges stored within the graph. However, because -GraphX maintains the vertices and edges in optimized data-structures and these data-structures +GraphX maintains the vertices and edges in optimized data structures and these data structures provide additional functionality, the vertices and edges are returned as `VertexRDD` and `EdgeRDD` respectively. In this section we review some of the additional useful functionality in these types. @@ -870,7 +887,7 @@ The `VertexRDD[A]` extends `RDD[(VertexID, A)]` and adds the additional constrai attribute of type `A`. Internally, this is achieved by storing the vertex attributes in a reusable hash-map data-structure. As a consequence if two `VertexRDD`s are derived from the same base `VertexRDD` (e.g., by `filter` or `mapValues`) they can be joined in constant time without hash -evaluations. To leverage this indexed data-structure, the `VertexRDD` exposes the following +evaluations. 
To leverage this indexed data structure, the `VertexRDD` exposes the following additional functionality: {% highlight scala %} @@ -893,7 +910,7 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] { Notice, for example, how the `filter` operator returns an `VertexRDD`. Filter is actually implemented using a `BitSet` thereby reusing the index and preserving the ability to do fast joins with other `VertexRDD`s. Likewise, the `mapValues` operators do not allow the `map` function to -change the `VertexID` thereby enabling the same `HashMap` data-structures to be reused. Both the +change the `VertexID` thereby enabling the same `HashMap` data structures to be reused. Both the `leftJoin` and `innerJoin` are able to identify when joining two `VertexRDD`s derived from the same `HashMap` and implement the join by linear scan rather than costly point lookups. @@ -916,21 +933,19 @@ val setC: VertexRDD[Double] = setA.innerJoin(setB)((id, a, b) => a + b) ## EdgeRDDs -The `EdgeRDD[ED, VD]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one +The `EdgeRDD[ED]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one of the various partitioning strategies defined in [`PartitionStrategy`][PartitionStrategy]. Within each partition, edge attributes and adjacency structure, are stored separately enabling maximum reuse when changing attribute values. -[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy - The three additional functions exposed by the `EdgeRDD` are: {% highlight scala %} // Transform the edge attributes while preserving the structure -def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] +def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2] // Revere the edges reusing both attributes and structure -def reverse: EdgeRDD[ED, VD] +def reverse: EdgeRDD[ED] // Join two `EdgeRDD`s partitioned using the same partitioning strategy. -def innerJoin[ED2, ED3](other: EdgeRDD[ED2, VD])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] +def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] {% endhighlight %} In most applications we have found that operations on the `EdgeRDD` are accomplished through the @@ -960,7 +975,6 @@ the [`Graph.partitionBy`][Graph.partitionBy] operator. The default partitioning the initial partitioning of the edges as provided on graph construction. However, users can easily switch to 2D-partitioning or other heuristics included in GraphX. -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED]
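As a brief illustration of switching strategies, the following sketch (assuming an existing `graph`, for example one loaded with `GraphLoader.edgeListFile`) repartitions the edges with the 2D heuristic and caches the result before running partition-sensitive operations such as `groupEdges` or triangle counting:

{% highlight scala %}
// Repartition the edges using the 2D partitioning heuristic and cache the result.
import org.apache.spark.graphx._
val partitionedGraph = graph.partitionBy(PartitionStrategy.EdgePartition2D).cache()
{% endhighlight %}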

    +# Graph Algorithms + GraphX includes a set of graph algorithms to simplify analytics tasks. The algorithms are contained in the `org.apache.spark.graphx.lib` package and can be accessed directly as methods on `Graph` via [`GraphOps`][GraphOps]. This section describes the algorithms and how they are used. -## PageRank +## PageRank + PageRank measures the importance of each vertex in a graph, assuming an edge from *u* to *v* represents an endorsement of *v*'s importance by *u*. For example, if a Twitter user is followed by many others, the user will be ranked highly. GraphX comes with static and dynamic implementations of PageRank as methods on the [`PageRank` object][PageRank]. Static PageRank runs for a fixed number of iterations, while dynamic PageRank runs until the ranks converge (i.e., stop changing by more than a specified tolerance). [`GraphOps`][GraphOps] allows calling these algorithms directly as methods on `Graph`. GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `graphx/data/users.txt`, and a set of relationships between users is given in `graphx/data/followers.txt`. We compute the PageRank of each user as follows: -[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ - {% highlight scala %} // Load the edges as a graph val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") @@ -1014,8 +1028,6 @@ println(ranksByUsername.collect().mkString("\n")) The connected components algorithm labels each connected component of the graph with the ID of its lowest-numbered vertex. For example, in a social network, connected components can approximate clusters. GraphX contains an implementation of the algorithm in the [`ConnectedComponents` object][ConnectedComponents], and we compute the connected components of the example social network dataset from the [PageRank section](#pagerank) as follows: -[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ - {% highlight scala %} // Load the graph as in the PageRank example val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") @@ -1037,9 +1049,6 @@ println(ccByUsername.collect().mkString("\n")) A vertex is part of a triangle when it has two adjacent vertices with an edge between them. GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] that determines the number of triangles passing through each vertex, providing a measure of clustering. We compute the triangle count of the social network dataset from the [PageRank section](#pagerank). 
*Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].* -[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED] - {% highlight scala %} // Load the edges in canonical order and partition the graph for triangle count val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt", true).partitionBy(PartitionStrategy.RandomVertexCut) diff --git a/docs/img/data_parallel_vs_graph_parallel.png b/docs/img/data_parallel_vs_graph_parallel.png deleted file mode 100644 index d3918f01d8f3b..0000000000000 Binary files a/docs/img/data_parallel_vs_graph_parallel.png and /dev/null differ diff --git a/docs/img/graph_analytics_pipeline.png b/docs/img/graph_analytics_pipeline.png deleted file mode 100644 index 6d606e01894ae..0000000000000 Binary files a/docs/img/graph_analytics_pipeline.png and /dev/null differ diff --git a/docs/img/tables_and_graphs.png b/docs/img/tables_and_graphs.png deleted file mode 100644 index ec37bb45a62f0..0000000000000 Binary files a/docs/img/tables_and_graphs.png and /dev/null differ diff --git a/docs/monitoring.md b/docs/monitoring.md index e3f81a76acdbb..f32cdef240d31 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -79,7 +79,7 @@ follows: spark.history.fs.logDirectory - (none) + file:/tmp/spark-events Directory that contains application event logs to be loaded by the history server diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 18420afb27e3c..7a16ee8742dc0 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/ how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object that contains information about your application. +Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one. + {% highlight scala %} val conf = new SparkConf().setAppName(appName).setMaster(master) new SparkContext(conf) @@ -1131,7 +1133,7 @@ method. The code below shows this: {% highlight scala %} scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) -broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) +broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) scala> broadcastVar.value res0: Array[Int] = Array(1, 2, 3) @@ -1175,7 +1177,7 @@ Accumulators are variables that are only "added" to through an associative opera therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be -displayed in Spark's UI. This can can be useful for understanding the progress of +displayed in Spark's UI. This can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python). An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. 
Tasks @@ -1304,6 +1306,12 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam()) +For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator +will only be applied once, i.e., restarted tasks will not update the value. In transformations, users should be aware +that each task's update may be applied more than once if tasks or job stages are re-executed. + + + # Deploying to a Cluster The [application submission guide](submitting-applications.html) describes how to submit applications to a cluster. diff --git a/docs/quick-start.md b/docs/quick-start.md index 6236de0e1f2c4..bf643bb70e153 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -244,6 +244,9 @@ object SimpleApp { } {% endhighlight %} +Note that applications should define a `main()` method instead of extending `scala.App`. +Subclasses of `scala.App` may not work correctly. + This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 695813a2ba881..dfe2db4b3fce8 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -4,7 +4,7 @@ title: Running Spark on YARN --- Support for running on [YARN (Hadoop -NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html) +NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. # Preparations @@ -39,7 +39,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.preserve.staging.files false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather then delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. @@ -159,7 +159,7 @@ For example: lib/spark-examples*.jar \ 10 -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs. +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: @@ -181,7 +181,7 @@ In YARN terminology, executors and application masters run inside "containers". yarn logs -applicationId -will print out the contents of all log files from all containers from the given application. +will print out the contents of all log files from all containers from the given application.
You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ffcce2c588879..24a68bb083334 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -14,7 +14,7 @@ title: Spark SQL Programming Guide Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, [SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects, along with +[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). @@ -728,7 +728,7 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run "`sbt/sbt -Phive assembly/assembly`" (or use `-Phive` for maven). +Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build. This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. @@ -900,7 +900,6 @@ export HIVE_SERVER2_THRIFT_BIND_HOST= ./sbin/start-thriftserver.sh \ --master \ ... -``` {% endhighlight %} or system properties: @@ -911,7 +910,6 @@ or system properties: --hiveconf hive.server2.thrift.bind.host= \ --master ... 
-``` {% endhighlight %} Now you can use beeline to test the Thrift JDBC/ODBC server: diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index a5396c2375915..b83decadc2988 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -33,6 +33,7 @@ import time import urllib2 import warnings +from datetime import datetime from optparse import OptionParser from sys import stderr import boto @@ -86,7 +87,7 @@ def parse_args(): "-z", "--zone", default="", help="Availability zone to launch instances in, or 'all' to spread " + "slaves across multiple (an additional $0.01/Gb for bandwidth" + - "between zones applies)") + "between zones applies) (default: a single zone chosen at random)") parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") parser.add_option( "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, @@ -138,7 +139,7 @@ def parse_args(): help="The SSH user you want to connect as (default: %default)") parser.add_option( "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created.") + help="When destroying a cluster, delete the security groups that were created") parser.add_option( "--use-existing-master", action="store_true", default=False, help="Launch fresh slaves, but use an existing stopped master if possible") @@ -152,9 +153,6 @@ def parse_args(): parser.add_option( "--user-data", type="string", default="", help="Path to a user-data file (most AMI's interpret this as an initialization script)") - parser.add_option( - "--security-group-prefix", type="string", default=None, - help="Use this prefix for the security group rather than the cluster name.") parser.add_option( "--authorized-address", type="string", default="0.0.0.0/0", help="Address to authorize on created security groups (default: %default)") @@ -305,12 +303,8 @@ def launch_cluster(conn, opts, cluster_name): user_data_content = user_data_file.read() print "Setting up security groups..." 
- if opts.security_group_prefix is None: - master_group = get_or_make_group(conn, cluster_name + "-master") - slave_group = get_or_make_group(conn, cluster_name + "-slaves") - else: - master_group = get_or_make_group(conn, opts.security_group_prefix + "-master") - slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves") + master_group = get_or_make_group(conn, cluster_name + "-master") + slave_group = get_or_make_group(conn, cluster_name + "-slaves") authorized_address = opts.authorized_address if master_group.rules == []: # Group was just now created master_group.authorize(src_group=master_group) @@ -335,11 +329,12 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60060, 60060, authorized_address) slave_group.authorize('tcp', 60075, 60075, authorized_address) - # Check if instances are already running with the cluster name + # Check if instances are already running in our groups existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances for name: %s " % cluster_name) + print >> stderr, ("ERROR: There are already instances running in " + + "group %s or %s" % (master_group.name, slave_group.name)) sys.exit(1) # Figure out Spark AMI @@ -413,13 +408,9 @@ def launch_cluster(conn, opts, cluster_name): for r in reqs: id_to_req[r.id] = r active_instance_ids = [] - outstanding_request_ids = [] for i in my_req_ids: - if i in id_to_req: - if id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) - else: - outstanding_request_ids.append(i) + if i in id_to_req and id_to_req[i].state == "active": + active_instance_ids.append(id_to_req[i].instance_id) if len(active_instance_ids) == opts.slaves: print "All %d slaves granted" % opts.slaves reservations = conn.get_all_instances(active_instance_ids) @@ -428,8 +419,8 @@ def launch_cluster(conn, opts, cluster_name): slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer for request ids including %s" % ( - len(active_instance_ids), opts.slaves, outstanding_request_ids[0:10]) + print "%d of %d slaves granted, waiting longer" % ( + len(active_instance_ids), opts.slaves) except: print "Canceling spot instance requests" conn.cancel_spot_instance_requests(my_req_ids) @@ -488,59 +479,34 @@ def launch_cluster(conn, opts, cluster_name): # Give the instances descriptive names for master in master_nodes: - name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id) - tag_instance(master, name) - + master.add_tag( + key='Name', + value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) for slave in slave_nodes: - name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id) - tag_instance(slave, name) + slave.add_tag( + key='Name', + value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) # Return all the instances return (master_nodes, slave_nodes) -def tag_instance(instance, name): - for i in range(0, 5): - try: - instance.add_tag(key='Name', value=name) - break - except: - print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if i == 5: - raise "Error - failed max attempts to add name tag" - time.sleep(5) - # Get the EC2 instances in an existing cluster if available. 
# Returns a tuple of lists of EC2 instance objects for the masters and slaves def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): print "Searching for existing cluster " + cluster_name + "..." - # Search all the spot instance requests, and copy any tags from the spot - # instance request to the cluster. - spot_instance_requests = conn.get_all_spot_instance_requests() - for req in spot_instance_requests: - if req.state != u'active': - continue - name = req.tags.get(u'Name', "") - if name.startswith(cluster_name): - reservations = conn.get_all_instances(instance_ids=[req.instance_id]) - for res in reservations: - active = [i for i in res.instances if is_active(i)] - for instance in active: - if instance.tags.get(u'Name') is None: - tag_instance(instance, name) - # Now proceed to detect master and slaves instances. reservations = conn.get_all_instances() master_nodes = [] slave_nodes = [] for res in reservations: active = [i for i in res.instances if is_active(i)] for inst in active: - name = inst.tags.get(u'Name', "") - if name.startswith(cluster_name + "-master"): + group_names = [g.name for g in inst.groups] + if group_names == [cluster_name + "-master"]: master_nodes.append(inst) - elif name.startswith(cluster_name + "-slave"): + elif group_names == [cluster_name + "-slaves"]: slave_nodes.append(inst) if any((master_nodes, slave_nodes)): print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) @@ -548,12 +514,12 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in with name " + \ - cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) + # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. @@ -624,7 +590,9 @@ def setup_spark_cluster(master, opts): def is_ssh_available(host, opts): - "Checks if SSH is available on the host." + """ + Check if SSH is available on a host. + """ try: with open(os.devnull, 'w') as devnull: ret = subprocess.check_call( @@ -639,6 +607,9 @@ def is_ssh_available(host, opts): def is_cluster_ssh_available(cluster_instances, opts): + """ + Check if SSH is available on all the instances in a cluster. + """ for i in cluster_instances: if not is_ssh_available(host=i.ip_address, opts=opts): return False @@ -646,8 +617,10 @@ def is_cluster_ssh_available(cluster_instances, opts): return True -def wait_for_cluster_state(cluster_instances, cluster_state, opts): +def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): """ + Wait for all the instances in the cluster to reach a designated state. 
+ cluster_instances: a list of boto.ec2.instance.Instance cluster_state: a string representing the desired state of all the instances in the cluster value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as @@ -655,20 +628,27 @@ def wait_for_cluster_state(cluster_instances, cluster_state, opts): (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) """ sys.stdout.write( - "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state) + "Waiting for cluster to enter '{s}' state.".format(s=cluster_state) ) sys.stdout.flush() + start_time = datetime.now() + num_attempts = 0 + conn = ec2.connect_to_region(opts.region) while True: - time.sleep(3 * num_attempts) + time.sleep(5 * num_attempts) # seconds for i in cluster_instances: - s = i.update() # capture output to suppress print to screen in newer versions of boto + i.update() + + statuses = conn.get_all_instance_status(instance_ids=[i.id for i in cluster_instances]) if cluster_state == 'ssh-ready': if all(i.state == 'running' for i in cluster_instances) and \ + all(s.system_status.status == 'ok' for s in statuses) and \ + all(s.instance_status.status == 'ok' for s in statuses) and \ is_cluster_ssh_available(cluster_instances, opts): break else: @@ -682,6 +662,12 @@ def wait_for_cluster_state(cluster_instances, cluster_state, opts): sys.stdout.write("\n") + end_time = datetime.now() + print "Cluster is now in '{s}' state. Waited {t} seconds.".format( + s=cluster_state, + t=(end_time - start_time).seconds + ) + # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): @@ -930,7 +916,7 @@ def real_main(): # See: https://docs.python.org/3.5/whatsnew/2.7.html warnings.warn( "This option is deprecated and has no effect. " - "spark-ec2 automatically waits as long as necessary for clusters to startup.", + "spark-ec2 automatically waits as long as necessary for clusters to start up.", DeprecationWarning ) @@ -957,9 +943,10 @@ def real_main(): else: (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) wait_for_cluster_state( + conn=conn, + opts=opts, cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready', - opts=opts + cluster_state='ssh-ready' ) setup_cluster(conn, master_nodes, slave_nodes, opts, True) @@ -984,15 +971,12 @@ def real_main(): # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." 
- if opts.security_group_prefix is None: - group_names = [cluster_name + "-master", cluster_name + "-slaves"] - else: - group_names = [opts.security_group_prefix + "-master", - opts.security_group_prefix + "-slaves"] + group_names = [cluster_name + "-master", cluster_name + "-slaves"] wait_for_cluster_state( + conn=conn, + opts=opts, cluster_instances=(master_nodes + slave_nodes), - cluster_state='terminated', - opts=opts + cluster_state='terminated' ) attempt = 1 while attempt <= 3: @@ -1094,9 +1078,10 @@ def real_main(): if inst.state not in ["shutting-down", "terminated"]: inst.start() wait_for_cluster_state( + conn=conn, + opts=opts, cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready', - opts=opts + cluster_state='ssh-ready' ) setup_cluster(conn, master_nodes, slave_nodes, opts, False) diff --git a/examples/pom.xml b/examples/pom.xml index 910eb55308b9d..8713230e1e8ed 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -34,48 +34,6 @@ Spark Project Examples http://spark.apache.org/ - - - kinesis-asl - - - org.apache.spark - spark-streaming-kinesis-asl_${scala.binary.version} - ${project.version} - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - - - hbase-hadoop2 - - - hbase.profile - hadoop2 - - - - 0.98.7-hadoop2 - - - - hbase-hadoop1 - - - !hbase.profile - - - - 0.98.7-hadoop1 - - - - - @@ -124,11 +82,6 @@ spark-streaming-twitter_${scala.binary.version} ${project.version} - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - org.apache.spark spark-streaming-flume_${scala.binary.version} @@ -136,12 +89,12 @@ org.apache.spark - spark-streaming-zeromq_${scala.binary.version} + spark-streaming-mqtt_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming-mqtt_${scala.binary.version} + spark-streaming-zeromq_${scala.binary.version} ${project.version} @@ -153,6 +106,11 @@ hbase-testing-util ${hbase.version} + + + org.apache.hbase + hbase-annotations + org.jruby jruby-complete @@ -168,12 +126,24 @@ org.apache.hbase hbase-common ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + org.apache.hbase hbase-client ${hbase.version} + + + org.apache.hbase + hbase-annotations + io.netty netty @@ -205,6 +175,11 @@ org.apache.hadoop hadoop-auth + + + org.apache.hbase + hbase-annotations + org.apache.hadoop hadoop-annotations @@ -260,15 +235,15 @@ test-jar test - - com.twitter - algebird-core_${scala.binary.version} - 0.1.11 - org.apache.commons commons-math3 + + com.twitter + algebird-core_${scala.binary.version} + 0.8.1 + org.scalatest scalatest_${scala.binary.version} @@ -401,4 +376,83 @@ + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + + + hbase-hadoop2 + + + hbase.profile + hadoop2 + + + + 0.98.7-hadoop2 + + + + hbase-hadoop1 + + + !hbase.profile + + + + 0.98.7-hadoop1 + + + + + scala-2.10 + + !scala-2.11 + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + scala-2.10/src/main/scala + scala-2.10/src/main/java + + + + + + + + + diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java 
b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java similarity index 92% rename from examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java rename to examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java index 430e96ab14d9d..e68ec74c3ed54 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java @@ -31,7 +31,7 @@ /** * Example of using Spark's status APIs from Java. */ -public final class JavaStatusAPIDemo { +public final class JavaStatusTrackerDemo { public static final String APP_NAME = "JavaStatusAPIDemo"; @@ -58,8 +58,8 @@ public static void main(String[] args) throws Exception { continue; } int currentJobId = jobIds.get(jobIds.size() - 1); - SparkJobInfo jobInfo = sc.getJobInfo(currentJobId); - SparkStageInfo stageInfo = sc.getStageInfo(jobInfo.stageIds()[0]); + SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId); + SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() + " active, " + stageInfo.numCompletedTasks() + " complete"); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java new file mode 100644 index 0000000000000..22ba68d8c354c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import org.apache.spark.sql.api.java.Row; +import org.apache.spark.SparkConf; + +/** + * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java + * bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of + * this example {@link SimpleTextClassificationPipeline}. Run with + *

    + * <pre>
    + * bin/run-example ml.JavaSimpleTextClassificationPipeline
    + * </pre>
    + */ +public class JavaSimpleTextClassificationPipeline { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); + JavaSparkContext jsc = new JavaSparkContext(conf); + JavaSQLContext jsql = new JavaSQLContext(jsc); + + // Prepare training documents, which are labeled. + List localTraining = Lists.newArrayList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0)); + JavaSchemaRDD training = + jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. + PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + List localTest = Lists.newArrayList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop")); + JavaSchemaRDD test = + jsql.applySchema(jsc.parallelize(localTest), Document.class); + + // Make predictions on test documents. + model.transform(test).registerAsTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + for (Row r: predictions.collect()) { + System.out.println(r); + } + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java similarity index 88% rename from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java rename to examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java index 1af2067b2b929..4a5ac404ea5ea 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java @@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.GradientBoostedTrees; import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; import org.apache.spark.mllib.util.MLUtils; /** * Classification and regression using gradient-boosted decision trees. 
*/ -public final class JavaGradientBoostedTrees { +public final class JavaGradientBoostedTreesRunner { private static void usage() { - System.err.println("Usage: JavaGradientBoostedTrees " + + System.err.println("Usage: JavaGradientBoostedTreesRunner " + " "); System.exit(-1); } @@ -55,7 +55,7 @@ public static void main(String[] args) { if (args.length > 2) { usage(); } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); @@ -64,7 +64,7 @@ public static void main(String[] args) { // Note: All features are treated as continuous. BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); boostingStrategy.setNumIterations(10); - boostingStrategy.weakLearnerParams().setMaxDepth(5); + boostingStrategy.treeStrategy().setMaxDepth(5); if (algo.equals("Classification")) { // Compute the number of classes from the data. @@ -73,10 +73,10 @@ public static void main(String[] args) { return p.label(); } }).countByValue().size(); - boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses); // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD predictionAndLabel = @@ -95,7 +95,7 @@ public static void main(String[] args) { System.out.println("Learned classification tree model:\n" + model); } else if (algo.equals("Regression")) { // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD predictionAndLabel = diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala new file mode 100644 index 0000000000000..ee7897d9062d9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml + +import scala.beans.BeanInfo + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.sql.SQLContext + +@BeanInfo +case class LabeledDocument(id: Long, text: String, label: Double) + +@BeanInfo +case class Document(id: Long, text: String) + +/** + * A simple text classification pipeline that recognizes "spark" from input text. This is to show + * how to create and configure an ML pipeline. Run with + * {{{ + * bin/run-example ml.SimpleTextClassificationPipeline + * }}} + */ +object SimpleTextClassificationPipeline { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Prepare training documents, which are labeled. + val training = sparkContext.parallelize(Seq( + LabeledDocument(0L, "a b c d e spark", 1.0), + LabeledDocument(1L, "b d", 0.0), + LabeledDocument(2L, "spark f g h", 1.0), + LabeledDocument(3L, "hadoop mapreduce", 0.0))) + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // Fit the pipeline to training documents. + val model = pipeline.fit(training) + + // Prepare test documents, which are unlabeled. + val test = sparkContext.parallelize(Seq( + Document(4L, "spark i j k"), + Document(5L, "l m n"), + Document(6L, "mapreduce spark"), + Document(7L, "apache hadoop"))) + + // Make predictions on test documents. 
+ model.transform(test) + .select('id, 'text, 'score, 'prediction) + .collect() + .foreach(println) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 1edd2432a0352..a113653810b93 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) extends AbstractParams[Params] + regParam: Double = 0.01) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 63f02cf7b98b9..98f9d1689c8e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,11 +22,11 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} +import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -349,24 +349,14 @@ object DecisionTreeRunner { sc.stop() } - /** - * Calculates the mean squared error for regression. - */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = tree.predict(y.features) - y.label - err * err - }.mean() - } - /** * Calculates the mean squared error for regression. 
*/ private[mllib] def meanSquaredError( - tree: WeightedEnsembleModel, + model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { data.map { y => - val err = tree.predict(y.features) - y.label + val err = model.predict(y.features) - y.label err * err }.mean() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala similarity index 91% rename from examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala rename to examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 9b6db01448be0..1def8b45a230c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -21,21 +21,21 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} import org.apache.spark.util.Utils /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * ./bin/run-example mllib.GradientBoostedTreesRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. * * Note: This script treats all features as real-valued (not categorical). * To include categorical features, modify categoricalFeaturesInfo. */ -object GradientBoostedTrees { +object GradientBoostedTreesRunner { case class Params( input: String = null, @@ -93,24 +93,24 @@ object GradientBoostedTrees { def run(params: Params) { - val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params") val sc = new SparkContext(conf) - println(s"GradientBoostedTrees with parameters:\n$params") + println(s"GradientBoostedTreesRunner with parameters:\n$params") // Load training and test data and cache it. 
val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) val boostingStrategy = BoostingStrategy.defaultParams(params.algo) - boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.treeStrategy.numClassesForClassification = numClasses boostingStrategy.numIterations = params.numIterations - boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + boostingStrategy.treeStrategy.maxDepth = params.maxDepth val randomSeed = Utils.random.nextInt() if (params.algo == "Classification") { val startTime = System.nanoTime() - val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val model = GradientBoostedTrees.train(training, boostingStrategy) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { @@ -127,7 +127,7 @@ object GradientBoostedTrees { println(s"Test accuracy = $testAccuracy") } else if (params.algo == "Regression") { val startTime = System.nanoTime() - val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val model = GradientBoostedTrees.train(training, boostingStrategy) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index e1f9622350135..6a456ba7ec07b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1U * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`. * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
*/ -object LinearRegression extends App { +object LinearRegression { object RegType extends Enumeration { type RegType = Value @@ -47,42 +47,44 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, regType: RegType = L2, - regParam: Double = 0.1) extends AbstractParams[Params] - - val defaultParams = Params() - - val parser = new OptionParser[Params]("LinearRegression") { - head("LinearRegression: an example app for linear regression.") - opt[Int]("numIterations") - .text("number of iterations") - .action((x, c) => c.copy(numIterations = x)) - opt[Double]("stepSize") - .text(s"initial step size, default: ${defaultParams.stepSize}") - .action((x, c) => c.copy(stepSize = x)) - opt[String]("regType") - .text(s"regularization type (${RegType.values.mkString(",")}), " + - s"default: ${defaultParams.regType}") - .action((x, c) => c.copy(regType = RegType.withName(x))) - opt[Double]("regParam") - .text(s"regularization parameter, default: ${defaultParams.regParam}") - arg[String]("") - .required() - .text("input paths to labeled examples in LIBSVM format") - .action((x, c) => c.copy(input = x)) - note( - """ - |For example, the following command runs this app on a synthetic dataset: - | - | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \ - | examples/target/scala-*/spark-examples-*.jar \ - | data/mllib/sample_linear_regression_data.txt - """.stripMargin) - } + regParam: Double = 0.01) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegression") { + head("LinearRegression: an example app for linear regression.") + opt[Int]("numIterations") + .text("number of iterations") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("stepSize") + .text(s"initial step size, default: ${defaultParams.stepSize}") + .action((x, c) => c.copy(stepSize = x)) + opt[String]("regType") + .text(s"regularization type (${RegType.values.mkString(",")}), " + + s"default: ${defaultParams.regType}") + .action((x, c) => c.copy(regType = RegType.withName(x))) + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + arg[String]("") + .required() + .text("input paths to labeled examples in LIBSVM format") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \ + | examples/target/scala-*/spark-examples-*.jar \ + | data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } - parser.parse(args, defaultParams).map { params => - run(params) - } getOrElse { - sys.exit(1) + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } } def run(params: Params) { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 0c52ef8ed96ac..227acc117502d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -27,6 +27,7 @@ object HiveFromSpark { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HiveFromSpark") val sc = new SparkContext(sparkConf) + val path = s"${System.getenv("SPARK_HOME")}/examples/src/main/resources/kv1.txt" // A local hive context creates an 
instance of the Hive Metastore in process, storing // the warehouse data in the current directory. This location can be overridden by @@ -35,7 +36,7 @@ object HiveFromSpark { import hiveContext._ sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index a4d159bf38377..ed186ea5650c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -18,12 +18,13 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf +import org.apache.spark.HashPartitioner import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every - * second. + * second starting with initial value of word count. * Usage: StatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. @@ -51,12 +52,19 @@ object StatefulNetworkWordCount { Some(currentCount + previousCount) } + val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { + iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + } + val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") - // Create a NetworkInputDStream on target ip:port and count the + // Initial RDD input to updateStateByKey + val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) + + // Create a ReceiverInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. 
generated by 'nc') val lines = ssc.socketTextStream(args(0), args(1).toInt) val words = lines.flatMap(_.split(" ")) @@ -64,7 +72,8 @@ object StatefulNetworkWordCount { // Update the cumulative count using updateStateByKey // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) + val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, + new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index d9b886eff77cc..55226c0a6df60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -50,7 +50,7 @@ object PageViewStream { val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) - // Create a NetworkInputDStream on target host:port and convert each line to a PageView + // Create a ReceiverInputDStream on target host:port and convert each line to a PageView val pageViews = ssc.socketTextStream(host, port) .flatMap(_.split("\n")) .map(PageView.fromString(_)) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index ac291bd4fde20..72618b6515f83 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..4411d6e20c52a --- /dev/null +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file streaming/target/unit-tests.log +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN + diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index a2b2cc6149d95..650b2fbe1c142 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -159,6 +159,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("transactionCapacity", 1000.toString) channelContext.put("keep-alive", 0.toString) channelContext.putAll(overrides) + channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) val sink = new SparkSink() diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 7d31e32283d88..a682f0e8471d8 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,19 +39,13 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} + provided
    org.apache.spark spark-streaming-flume-sink_${scala.binary.version} ${project.version} - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.flume flume-ng-sdk diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala new file mode 100644 index 0000000000000..1a900007b696b --- /dev/null +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.{IOException, ObjectInputStream} + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} +import org.apache.spark.util.Utils + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** + * This is a output stream just for the testsuites. 
All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * + * The buffer contains a sequence of RDD's, each containing a sequence of items + */ +class TestOutputStream[T: ClassTag](parent: DStream[T], + val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]()) + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + output.clear() + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 475026e8eb140..b57a1c71e35b9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -20,9 +20,6 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} -import java.util.Random - -import org.apache.spark.TestUtils import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -32,20 +29,35 @@ import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} +import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.Utils -class FlumePollingStreamSuite extends TestSuiteBase { +class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch val channelCapacity = 5000 val maxAttempts = 5 + val batchDuration = Seconds(1) + + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName(this.getClass.getSimpleName) + + def beforeFunction() { + logInfo("Using manual clock") + conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + } + + before(beforeFunction()) test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -229,4 +241,5 @@ class FlumePollingStreamSuite extends TestSuiteBase { null } } + } diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 2067c473f0e3f..b3f44471cd326 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.apache.kafka diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 
e20e2c8f26991..4d26b640e8d74 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -17,23 +17,21 @@ package org.apache.spark.streaming.kafka +import java.util.Properties + import scala.collection.Map import scala.reflect.{classTag, ClassTag} -import java.util.Properties -import java.util.concurrent.Executors - -import kafka.consumer._ +import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient._ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils /** * Input stream that pulls messages from a Kafka Broker. @@ -53,12 +51,16 @@ class KafkaInputDStream[ @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], + useReliableReceiver: Boolean, storageLevel: StorageLevel ) extends ReceiverInputDStream[(K, V)](ssc_) with Logging { def getReceiver(): Receiver[(K, V)] = { - new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) - .asInstanceOf[Receiver[(K, V)]] + if (!useReliableReceiver) { + new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } else { + new ReliableKafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } } } @@ -71,14 +73,15 @@ class KafkaReceiver[ kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel - ) extends Receiver[Any](storageLevel) with Logging { + ) extends Receiver[(K, V)](storageLevel) with Logging { // Connection to Kafka - var consumerConnector : ConsumerConnector = null + var consumerConnector: ConsumerConnector = null def onStop() { if (consumerConnector != null) { consumerConnector.shutdown() + consumerConnector = null } } @@ -97,12 +100,6 @@ class KafkaReceiver[ consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + zkConnect) - // When auto.offset.reset is defined, it is our responsibility to try and whack the - // consumer group zk node. 
- if (kafkaParams.contains("auto.offset.reset")) { - tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) - } - val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[K]] @@ -110,11 +107,11 @@ class KafkaReceiver[ .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] - // Create Threads for each Topic/Message Stream we are listening + // Create threads for each topic/message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Executors.newFixedThreadPool(topics.values.sum) + val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -125,13 +122,15 @@ class KafkaReceiver[ } } - // Handles Kafka Messages - private class MessageHandler[K: ClassTag, V: ClassTag](stream: KafkaStream[K, V]) + // Handles Kafka messages + private class MessageHandler(stream: KafkaStream[K, V]) extends Runnable { def run() { logInfo("Starting MessageHandler.") try { - for (msgAndMetadata <- stream) { + val streamIterator = stream.iterator() + while (streamIterator.hasNext()) { + val msgAndMetadata = streamIterator.next() store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { @@ -139,26 +138,4 @@ class KafkaReceiver[ } } } - - // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This - // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper. - // - // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied - // from Kafka's ConsoleConsumer. 
See code related to 'auto.offset.reset' when it is set to - // 'smallest'/'largest': - // scalastyle:off - // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala - // scalastyle:on - private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { - val dir = "/consumers/" + groupId - logInfo("Cleaning up temporary Zookeeper data under " + dir + ".") - val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) - try { - zk.deleteRecursive(dir) - } catch { - case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e) - } finally { - zk.close() - } - } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 48668f763e41e..b4ac929e0c070 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.kafka -import scala.reflect.ClassTag -import scala.collection.JavaConversions._ - import java.lang.{Integer => JInt} import java.util.{Map => JMap} +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + import kafka.serializer.{Decoder, StringDecoder} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext, JavaPairDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} - +import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object KafkaUtils { /** @@ -71,7 +70,8 @@ object KafkaUtils { topics: Map[String, Int], storageLevel: StorageLevel ): ReceiverInputDStream[(K, V)] = { - new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, storageLevel) + val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false) + new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel) } /** @@ -100,7 +100,6 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. - * */ def createStream( jssc: JavaStreamingContext, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala new file mode 100644 index 0000000000000..be734b80272d1 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.util.Properties +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap} + +import scala.collection.{Map, mutable} +import scala.reflect.{ClassTag, classTag} + +import kafka.common.TopicAndPartition +import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder +import kafka.utils.{VerifiableProperties, ZKGroupTopicDirs, ZKStringSerializer, ZkUtils} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} +import org.apache.spark.util.Utils + +/** + * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. + * It is turned off by default and will be enabled when + * spark.streaming.receiver.writeAheadLog.enable is true. The difference compared to KafkaReceiver + * is that this receiver manages topic-partition/offset itself and updates the offset information + * after data is reliably stored as write-ahead log. Offsets will only be updated when data is + * reliably stored, so the potential data loss problem of KafkaReceiver can be eliminated. + * + * Note: ReliableKafkaReceiver will set auto.commit.enable to false to turn off automatic offset + * commit mechanism in Kafka consumer. So setting this configuration manually within kafkaParams + * will not take effect. + */ +private[streaming] +class ReliableKafkaReceiver[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + storageLevel: StorageLevel) + extends Receiver[(K, V)](storageLevel) with Logging { + + private val groupId = kafkaParams("group.id") + private val AUTO_OFFSET_COMMIT = "auto.commit.enable" + private def conf = SparkEnv.get.conf + + /** High level consumer to connect to Kafka. */ + private var consumerConnector: ConsumerConnector = null + + /** zkClient to connect to Zookeeper to commit the offsets. */ + private var zkClient: ZkClient = null + + /** + * A HashMap to manage the offset for each topic/partition, this HashMap is called in + * synchronized block, so mutable HashMap will not meet concurrency issue. + */ + private var topicPartitionOffsetMap: mutable.HashMap[TopicAndPartition, Long] = null + + /** A concurrent HashMap to store the stream block id and related offset snapshot. */ + private var blockOffsetMap: ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]] = null + + /** + * Manage the BlockGenerator in receiver itself for better managing block store and offset + * commit. + */ + private var blockGenerator: BlockGenerator = null + + /** Thread pool running the handlers for receiving message from multiple topics and partitions. 
*/ + private var messageHandlerThreadPool: ThreadPoolExecutor = null + + override def onStart(): Unit = { + logInfo(s"Starting Kafka Consumer Stream with group: $groupId") + + // Initialize the topic-partition / offset hash map. + topicPartitionOffsetMap = new mutable.HashMap[TopicAndPartition, Long] + + // Initialize the stream block id / offset snapshot hash map. + blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() + + // Initialize the block generator for storing Kafka message. + blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + + if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { + logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + + "otherwise we will manually set it to false to turn off auto offset commit in Kafka") + } + + val props = new Properties() + kafkaParams.foreach(param => props.put(param._1, param._2)) + // Manually set "auto.commit.enable" to "false" no matter user explicitly set it to true, + // we have to make sure this property is set to false to turn off auto commit mechanism in + // Kafka. + props.setProperty(AUTO_OFFSET_COMMIT, "false") + + val consumerConfig = new ConsumerConfig(props) + + assert(!consumerConfig.autoCommitEnable) + + logInfo(s"Connecting to Zookeeper: ${consumerConfig.zkConnect}") + consumerConnector = Consumer.create(consumerConfig) + logInfo(s"Connected to Zookeeper: ${consumerConfig.zkConnect}") + + zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, + consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) + + messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + topics.values.sum, "KafkaMessageHandler") + + blockGenerator.start() + + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[K]] + + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[V]] + + val topicMessageStreams = consumerConnector.createMessageStreams( + topics, keyDecoder, valueDecoder) + + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => + messageHandlerThreadPool.submit(new MessageHandler(stream)) + } + } + } + + override def onStop(): Unit = { + if (messageHandlerThreadPool != null) { + messageHandlerThreadPool.shutdown() + messageHandlerThreadPool = null + } + + if (consumerConnector != null) { + consumerConnector.shutdown() + consumerConnector = null + } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (blockGenerator != null) { + blockGenerator.stop() + blockGenerator = null + } + + if (topicPartitionOffsetMap != null) { + topicPartitionOffsetMap.clear() + topicPartitionOffsetMap = null + } + + if (blockOffsetMap != null) { + blockOffsetMap.clear() + blockOffsetMap = null + } + } + + /** Store a Kafka message and the associated metadata as a tuple. 
*/ + private def storeMessageAndMetadata( + msgAndMetadata: MessageAndMetadata[K, V]): Unit = { + val topicAndPartition = TopicAndPartition(msgAndMetadata.topic, msgAndMetadata.partition) + val data = (msgAndMetadata.key, msgAndMetadata.message) + val metadata = (topicAndPartition, msgAndMetadata.offset) + blockGenerator.addDataWithCallback(data, metadata) + } + + /** Update stored offset */ + private def updateOffset(topicAndPartition: TopicAndPartition, offset: Long): Unit = { + topicPartitionOffsetMap.put(topicAndPartition, offset) + } + + /** + * Remember the current offsets for each topic and partition. This is called when a block is + * generated. + */ + private def rememberBlockOffsets(blockId: StreamBlockId): Unit = { + // Get a snapshot of current offset map and store with related block id. + val offsetSnapshot = topicPartitionOffsetMap.toMap + blockOffsetMap.put(blockId, offsetSnapshot) + topicPartitionOffsetMap.clear() + } + + /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + private def storeBlockAndCommitOffset( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } + + /** + * Commit the offset of Kafka's topic/partition, the commit mechanism follow Kafka 0.8.x's + * metadata schema in Zookeeper. + */ + private def commitOffset(offsetMap: Map[TopicAndPartition, Long]): Unit = { + if (zkClient == null) { + val thrown = new IllegalStateException("Zookeeper client is unexpectedly null") + stop("Zookeeper client is not initialized before commit offsets to ZK", thrown) + return + } + + for ((topicAndPart, offset) <- offsetMap) { + try { + val topicDirs = new ZKGroupTopicDirs(groupId, topicAndPart.topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/${topicAndPart.partition}" + + ZkUtils.updatePersistentPath(zkClient, zkPath, offset.toString) + } catch { + case e: Exception => + logWarning(s"Exception during commit offset $offset for topic" + + s"${topicAndPart.topic}, partition ${topicAndPart.partition}", e) + } + + logInfo(s"Committed offset $offset for topic ${topicAndPart.topic}, " + + s"partition ${topicAndPart.partition}") + } + } + + /** Class to handle received Kafka message. */ + private final class MessageHandler(stream: KafkaStream[K, V]) extends Runnable { + override def run(): Unit = { + while (!isStopped) { + try { + val streamIterator = stream.iterator() + while (streamIterator.hasNext) { + storeMessageAndMetadata(streamIterator.next) + } + } catch { + case e: Exception => + logError("Error handling message", e) + } + } + } + } + + /** Class to handle blocks generated by the block generator. 
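+ * Offsets are updated as data is added (onAddData), snapshotted per block when a block is generated (onGenerateBlock), and committed to ZooKeeper only after the block has been stored (onPushBlock).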
*/ + private final class GeneratedBlockHandler extends BlockGeneratorListener { + + def onAddData(data: Any, metadata: Any): Unit = { + // Update the offset of the data that was added to the generator + if (metadata != null) { + val (topicAndPartition, offset) = metadata.asInstanceOf[(TopicAndPartition, Long)] + updateOffset(topicAndPartition, offset) + } + } + + def onGenerateBlock(blockId: StreamBlockId): Unit = { + // Remember the offsets of topics/partitions when a block has been generated + rememberBlockOffsets(blockId) + } + + def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + // Store block and commit the blocks offset + storeBlockAndCommitOffset(blockId, arrayBuffer) + } + + def onError(message: String, throwable: Throwable): Unit = { + reportError(message, throwable) + } + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index efb0099c7c850..6e1abf3f385ee 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -20,7 +20,10 @@ import java.io.Serializable; import java.util.HashMap; import java.util.List; +import java.util.Random; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.Duration; import scala.Predef; import scala.Tuple2; import scala.collection.JavaConverters; @@ -32,8 +35,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.LocalJavaStreamingContext; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -42,25 +43,27 @@ import org.junit.After; import org.junit.Before; -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { - private transient KafkaStreamSuite testSuite = new KafkaStreamSuite(); +public class JavaKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient Random random = new Random(); + private transient KafkaStreamSuiteBase suiteBase = null; @Before - @Override public void setUp() { - testSuite.beforeFunction(); + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); System.clearProperty("spark.driver.port"); - //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, new Duration(500)); } @After - @Override public void tearDown() { ssc.stop(); ssc = null; System.clearProperty("spark.driver.port"); - testSuite.afterFunction(); + suiteBase.tearDownKafka(); } @Test @@ -74,15 +77,15 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - testSuite.createTopic(topic); + suiteBase.createTopic(topic); HashMap tmp = new HashMap(sent); - testSuite.produceAndSendMessage(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms())); + suiteBase.produceAndSendMessage(topic, + 
JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", testSuite.zkHost() + ":" + testSuite.zkPort()); - kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); JavaPairDStream stream = KafkaUtils.createStream(ssc, @@ -124,11 +127,16 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); - ssc.awaitTermination(3000); - + long startTime = System.currentTimeMillis(); + boolean sizeMatches = false; + while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { + sizeMatches = sent.size() == result.size(); + Thread.sleep(200); + } Assert.assertEquals(sent.size(), result.size()); for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } + ssc.stop(); } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 6943326eb750e..b19c053ebfc44 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -19,51 +19,57 @@ package org.apache.spark.streaming.kafka import java.io.File import java.net.InetSocketAddress -import java.util.{Properties, Random} +import java.util.Properties import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random import kafka.admin.CreateTopicCommand import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, ProducerConfig, Producer} -import kafka.utils.ZKStringSerializer +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.{StringDecoder, StringEncoder} import kafka.server.{KafkaConfig, KafkaServer} - +import kafka.utils.ZKStringSerializer import org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually -import org.apache.zookeeper.server.ZooKeeperServer -import org.apache.zookeeper.server.NIOServerCnxnFactory - -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class KafkaStreamSuite extends TestSuiteBase { - import KafkaTestUtils._ - - val zkHost = "localhost" - var zkPort: Int = 0 - val zkConnectionTimeout = 6000 - val zkSessionTimeout = 6000 - - protected var brokerPort = 9092 - protected var brokerConf: KafkaConfig = _ - protected var zookeeper: EmbeddedZookeeper = _ - protected var zkClient: ZkClient = _ - protected var server: KafkaServer = _ - protected var producer: Producer[String, String] = _ - - override def useManualClock = false - - override def beforeFunction() { +/** + * This is an abstract base class for Kafka testsuites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. 
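+ * Concrete suites call `setupKafka()` and `tearDownKafka()` around each test and publish test data with `createTopic()` and `produceAndSendMessage()`.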
+ */ +abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { + + var zkAddress: String = _ + var zkClient: ZkClient = _ + + private val zkHost = "localhost" + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + private var zookeeper: EmbeddedZookeeper = _ + private var zkPort: Int = 0 + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + private var server: KafkaServer = _ + private var producer: Producer[String, String] = _ + + def setupKafka() { // Zookeeper server startup zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") // Get the actual zookeeper binding port zkPort = zookeeper.actualPort + zkAddress = s"$zkHost:$zkPort" logInfo("==================== 0 ====================") - zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) logInfo("==================== 1 ====================") @@ -71,7 +77,7 @@ class KafkaStreamSuite extends TestSuiteBase { var bindSuccess: Boolean = false while(!bindSuccess) { try { - val brokerProps = getBrokerConfig(brokerPort, s"$zkHost:$zkPort") + val brokerProps = getBrokerConfig() brokerConf = new KafkaConfig(brokerProps) server = new KafkaServer(brokerConf) logInfo("==================== 2 ====================") @@ -89,53 +95,30 @@ class KafkaStreamSuite extends TestSuiteBase { Thread.sleep(2000) logInfo("==================== 4 ====================") - super.beforeFunction() } - override def afterFunction() { - producer.close() - server.shutdown() - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - zkClient.close() - zookeeper.shutdown() - - super.afterFunction() - } - - test("Kafka input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val topic = "topic1" - val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - produceAndSendMessage(topic, sent) + def tearDownKafka() { + if (producer != null) { + producer.close() + producer = null + } - val kafkaParams = Map("zookeeper.connect" -> s"$zkHost:$zkPort", - "group.id" -> s"test-consumer-${random.nextInt(10000)}", - "auto.offset.reset" -> "smallest") + if (server != null) { + server.shutdown() + server = null + } - val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, - kafkaParams, - Map(topic -> 1), - StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() - stream.map { case (k, v) => v } - .countByValue() - .foreachRDD { r => - val ret = r.collect() - ret.toMap.foreach { kv => - val count = result.getOrElseUpdate(kv._1, 0) + kv._2 - result.put(kv._1, count) - } - } - ssc.start() - ssc.awaitTermination(3000) + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - assert(sent.size === result.size) - sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + if (zkClient != null) { + zkClient.close() + zkClient = null + } - ssc.stop() + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } } private def createTestMessage(topic: String, sent: Map[String, Int]) @@ -150,58 +133,43 @@ class KafkaStreamSuite extends TestSuiteBase { CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") logInfo("==================== 5 ====================") // wait until metadata is propagated - waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + waitUntilMetadataIsPropagated(topic, 0) } def produceAndSendMessage(topic: String, sent: 
Map[String, Int]) { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) producer.send(createTestMessage(topic, sent): _*) + producer.close() logInfo("==================== 6 ====================") } -} - -object KafkaTestUtils { - val random = new Random() - def getBrokerConfig(port: Int, zkConnect: String): Properties = { + private def getBrokerConfig(): Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") - props.put("port", port.toString) + props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkConnect) + props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props } - def getProducerConfig(brokerList: String): Properties = { + private def getProducerConfig(): Properties = { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port val props = new Properties() - props.put("metadata.broker.list", brokerList) + props.put("metadata.broker.list", brokerAddr) props.put("serializer.class", classOf[StringEncoder].getName) props } - def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { - val startTime = System.currentTimeMillis() - while (true) { - if (condition()) - return true - if (System.currentTimeMillis() > startTime + waitTime) - return false - Thread.sleep(waitTime.min(100L)) + private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { + eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { + assert( + server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) } - // Should never go to here - throw new RuntimeException("unexpected error") - } - - def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, - timeout: Long) { - assert(waitUntilTrue(() => - servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( - TopicAndPartition(topic, partition))), timeout), - s"Partition [$topic, $partition] metadata not propagated after timeout") } class EmbeddedZookeeper(val zkConnect: String) { @@ -227,3 +195,53 @@ object KafkaTestUtils { } } } + + +class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { + var ssc: StreamingContext = _ + + before { + setupKafka() + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + tearDownKafka() + } + + test("Kafka input stream") { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map(_._2).countByValue().foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + 
} + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + assert(sent.size === result.size) + sent.keys.foreach { k => + assert(sent(k) === result(k).toInt) + } + } + ssc.stop() + } +} + diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala new file mode 100644 index 0000000000000..64ccc92c81fa9 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + + +import java.io.File + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.google.common.io.Files +import kafka.serializer.StringDecoder +import kafka.utils.{ZKGroupTopicDirs, ZkUtils} +import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually + +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} + +class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { + + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + + + var groupId: String = _ + var kafkaParams: Map[String, String] = _ + var ssc: StreamingContext = _ + var tempDirectory: File = null + + before { + setupKafka() + groupId = s"test-consumer-${Random.nextInt(10000)}" + kafkaParams = Map( + "zookeeper.connect" -> zkAddress, + "group.id" -> groupId, + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + tempDirectory = Files.createTempDir() + ssc.checkpoint(tempDirectory.getAbsolutePath) + } + + after { + if (ssc != null) { + ssc.stop() + } + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null + } + tearDownKafka() + } + + + test("Reliable Kafka input stream with single topic") { + var topic = "test-topic" + createTopic(topic) + produceAndSendMessage(topic, data) + + // Verify whether the offset of this group/topic/partition is 0 before starting. 
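+ // (getCommitOffset returns None while nothing has been committed to ZooKeeper for this group/topic/partition.)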
+ assert(getCommitOffset(groupId, topic, 0) === None) + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v }.foreachRDD { r => + val ret = r.collect() + ret.foreach { v => + val count = result.getOrElseUpdate(v, 0) + 1 + result.put(v, count) + } + } + ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { + // A basic process verification for ReliableKafkaReceiver. + // Verify whether received message number is equal to the sent message number. + assert(data.size === result.size) + // Verify whether each message is the same as the data to be verified. + data.keys.foreach { k => assert(data(k) === result(k).toInt) } + // Verify the offset number whether it is equal to the total message number. + assert(getCommitOffset(groupId, topic, 0) === Some(29L)) + } + ssc.stop() + } + + test("Reliable Kafka input stream with multiple topics") { + val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) + topics.foreach { case (t, _) => + createTopic(t) + produceAndSendMessage(t, data) + } + + // Before started, verify all the group/topic/partition offsets are 0. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === None) } + + // Consuming all the data sent to the broker which will potential commit the offsets internally. + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) + stream.foreachRDD(_ => Unit) + ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { + // Verify the offset for each group/topic to see whether they are equal to the expected one. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } + } + ssc.stop() + } + + + /** Getting partition offset from Zookeeper. */ + private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { + assert(zkClient != null, "Zookeeper client is not initialized") + val topicDirs = new ZKGroupTopicDirs(groupId, topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" + ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + } +} diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 371f1f1e9d39a..9025915f4447e 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,23 +39,12 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.eclipse.paho - mqtt-client - 0.4.0 - - - ${akka.group} - akka-zeromq_${scala.binary.version} - ${akka.version} + org.eclipse.paho.client.mqttv3 + 1.0.1 org.scalatest diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 467fd263e2d64..84595acf45ccb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,11 +17,19 @@ package org.apache.spark.streaming.mqtt -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.scalatest.FunSuite + +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class MQTTStreamSuite extends TestSuiteBase { +class MQTTStreamSuite extends FunSuite { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("mqtt input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 1d7dd49d15c22..000ace1446e5e 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.twitter4j diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 93741e0375164..9ee57d7581d85 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -17,13 +17,23 @@ package org.apache.spark.streaming.twitter -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} -import org.apache.spark.storage.StorageLevel + +import org.scalatest.{BeforeAndAfter, FunSuite} +import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} + +import org.apache.spark.Logging +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import twitter4j.Status -class TwitterStreamSuite extends TestSuiteBase { +class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("twitter input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7e48968feb3bc..29c452093502e 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided ${akka.group} diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index cc10ff6ae03cd..a7566e733d891 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,12 +20,19 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe +import org.scalatest.FunSuite import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends TestSuiteBase { +class ZeroMQStreamSuite extends FunSuite { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("zeromq input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 7e478bed62da7..c8477a6566311 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 560244ad93369..c0d3a61119113 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 71a078d58a8d8..d1427f6a0c6e9 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 3f49b1d63b6e1..9982b36f9b62f 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala new file mode 100644 index 0000000000000..f70715fca6eea --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +/** + * Represents an edge along with its neighboring vertices and allows sending messages along the + * edge. Used in [[Graph#aggregateMessages]]. + */ +abstract class EdgeContext[VD, ED, A] { + /** The vertex id of the edge's source vertex. */ + def srcId: VertexId + /** The vertex id of the edge's destination vertex. */ + def dstId: VertexId + /** The vertex attribute of the edge's source vertex. */ + def srcAttr: VD + /** The vertex attribute of the edge's destination vertex. */ + def dstAttr: VD + /** The attribute associated with the edge. */ + def attr: ED + + /** Sends a message to the source vertex. */ + def sendToSrc(msg: A): Unit + /** Sends a message to the destination vertex. */ + def sendToDst(msg: A): Unit + + /** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */ + def toEdgeTriplet: EdgeTriplet[VD, ED] = { + val et = new EdgeTriplet[VD, ED] + et.srcId = srcId + et.srcAttr = srcAttr + et.dstId = dstId + et.dstAttr = dstAttr + et.attr = attr + et + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 5267560b3e5ce..cc70b396a8dd4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -17,14 +17,19 @@ package org.apache.spark.graphx -import scala.reflect.{classTag, ClassTag} +import scala.language.existentials +import scala.reflect.ClassTag -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.Dependency +import org.apache.spark.Partition +import org.apache.spark.SparkContext +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.graphx.impl.EdgePartitionBuilder +import org.apache.spark.graphx.impl.EdgeRDDImpl /** * `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each @@ -32,33 +37,16 @@ import org.apache.spark.graphx.impl.EdgePartitionBuilder * edge to provide the triplet view. Shipping of the vertex attributes is managed by * `impl.ReplicatedVertexView`. 
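+ * With this change `EdgeRDD` becomes an abstract class; the concrete columnar implementation is `impl.EdgeRDDImpl`.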
*/ -class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( - val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], - val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) - extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { +abstract class EdgeRDD[ED]( + @transient sc: SparkContext, + @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { - override def setName(_name: String): this.type = { - if (partitionsRDD.name != null) { - partitionsRDD.setName(partitionsRDD.name + ", " + _name) - } else { - partitionsRDD.setName(_name) - } - this - } - setName("EdgeRDD") + private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } override protected def getPartitions: Array[Partition] = partitionsRDD.partitions - /** - * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the - * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new - * partitioner that allows co-partitioning with `partitionsRDD`. - */ - override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) - override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { - val p = firstParent[(PartitionID, EdgePartition[ED, VD])].iterator(part, context) + val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { p.next._2.iterator.map(_.copy()) } else { @@ -66,45 +54,6 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( } } - override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() - - /** - * Persists the edge partitions at the specified storage level, ignoring any existing target - * storage level. - */ - override def persist(newLevel: StorageLevel): this.type = { - partitionsRDD.persist(newLevel) - this - } - - override def unpersist(blocking: Boolean = true): this.type = { - partitionsRDD.unpersist(blocking) - this - } - - /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ - override def cache(): this.type = { - partitionsRDD.persist(targetStorageLevel) - this - } - - /** The number of edges in the RDD. */ - override def count(): Long = { - partitionsRDD.map(_._2.size.toLong).reduce(_ + _) - } - - private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( - f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = { - this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => - if (iter.hasNext) { - val (pid, ep) = iter.next() - Iterator(Tuple2(pid, f(pid, ep))) - } else { - Iterator.empty - } - }, preservesPartitioning = true)) - } - /** * Map the values in an edge partitioning preserving the structure but changing the values. * @@ -112,22 +61,14 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * @param f the function from an edge to a new edge value * @return a new EdgeRDD containing the new edge values */ - def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] = - mapEdgePartitions((pid, part) => part.map(f)) + def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] /** * Reverse all the edges in this RDD. * * @return a new EdgeRDD containing all the edges reversed */ - def reverse: EdgeRDD[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) - - /** Removes all edges but those matching `epred` and where both vertices match `vpred`. 
*/ - def filter( - epred: EdgeTriplet[VD, ED] => Boolean, - vpred: (VertexId, VD) => Boolean): EdgeRDD[ED, VD] = { - mapEdgePartitions((pid, part) => part.filter(epred, vpred)) - } + def reverse: EdgeRDD[ED] /** * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same @@ -139,23 +80,8 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * with values supplied by `f` */ def innerJoin[ED2: ClassTag, ED3: ClassTag] - (other: EdgeRDD[ED2, _]) - (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] = { - val ed2Tag = classTag[ED2] - val ed3Tag = classTag[ED3] - this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { - (thisIter, otherIter) => - val (pid, thisEPart) = thisIter.next() - val (_, otherEPart) = otherIter.next() - Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) - }) - } - - /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. */ - private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( - partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDD[ED2, VD2] = { - new EdgeRDD(partitionsRDD, this.targetStorageLevel) - } + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] /** * Changes the target storage level while preserving all other properties of the @@ -164,11 +90,7 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * This does not actually trigger a cache; to do this, call * [[org.apache.spark.graphx.EdgeRDD#cache]] on the returned EdgeRDD. */ - private[graphx] def withTargetStorageLevel( - targetStorageLevel: StorageLevel): EdgeRDD[ED, VD] = { - new EdgeRDD(this.partitionsRDD, targetStorageLevel) - } - + private[graphx] def withTargetStorageLevel(targetStorageLevel: StorageLevel): EdgeRDD[ED] } object EdgeRDD { @@ -178,7 +100,7 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDD[ED, VD] = { + def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDDImpl[ED, VD] = { val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) => val builder = new EdgePartitionBuilder[ED, VD] iter.foreach { e => @@ -195,8 +117,8 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( - edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDD[ED, VD] = { - new EdgeRDD(edgePartitions) + private[graphx] def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( + edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDDImpl[ED, VD] = { + new EdgeRDDImpl(edgePartitions) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index fa4b891754c40..637791543514c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. 
* */ - @transient val edges: EdgeRDD[ED, VD] + @transient val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with @@ -208,7 +208,37 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * */ def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { - mapTriplets((pid, iter) => iter.map(map)) + mapTriplets((pid, iter) => iter.map(map), TripletFields.All) + } + + /** + * Transforms each edge attribute using the map function, passing it the adjacent vertex + * attributes as well. If adjacent vertex values are not required, + * consider using `mapEdges` instead. + * + * @note This does not change the structure of the + * graph or modify the values of this graph. As a consequence + * the underlying index structures can be reused. + * + * @param map the function from an edge object to a new edge value. + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. + * + * @tparam ED2 the new edge data type + * + * @example This function might be used to initialize edge + * attributes based on the attributes associated with each vertex. + * {{{ + * val rawGraph: Graph[Int, Int] = someLoadFunction() + * val graph = rawGraph.mapTriplets[Int]( edge => + * edge.src.data - edge.dst.data) + * }}} + * + */ + def mapTriplets[ED2: ClassTag]( + map: EdgeTriplet[VD, ED] => ED2, + tripletFields: TripletFields): Graph[VD, ED2] = { + mapTriplets((pid, iter) => iter.map(map), tripletFields) } /** @@ -223,12 +253,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * the underlying index structures can be reused. * * @param map the iterator transform + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. * * @tparam ED2 the new edge data type * */ - def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]) - : Graph[VD, ED2] + def mapTriplets[ED2: ClassTag]( + map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] /** * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned @@ -287,6 +320,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of * the map phase destined to each vertex. * + * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead. + * * @tparam A the type of "message" to be sent to each vertex * * @param mapFunc the user defined map function which returns 0 or @@ -296,13 +331,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * be commutative and associative and is used to combine the output * of the map phase * - * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to - * consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on - * edges with destination in the active set. If the direction is `Out`, - * `mapFunc` will only be run on edges originating from vertices in the active set. If the - * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set - * . 
If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the - * active set. The active set must have the same index as the graph's vertices. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run only on edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. * * @example We can use this function to compute the in-degree of each * vertex @@ -319,6 +356,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * predicate or implement PageRank. * */ + @deprecated("use aggregateMessages", "1.2.0") def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, @@ -326,8 +364,80 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab : VertexRDD[A] /** - * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The - * input table should contain at most one entry for each vertex. If no entry in `other` is + * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * + * @example We can use this function to compute the in-degree of each + * vertex + * {{{ + * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") + * val inDeg: RDD[(VertexId, Int)] = + * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * }}} + * + * @note By expressing computation at the edge level we achieve + * maximum parallelism. This is one of the core functions in the + * Graph API in that enables neighborhood level computation. For + * example this function can be used to count neighbors satisfying a + * predicate or implement PageRank. + * + */ + def aggregateMessages[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields = TripletFields.All) + : VertexRDD[A] = { + aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None) + } + + /** + * Aggregates values from the neighboring edges and vertices of each vertex. 
The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * This variant can take an active set to restrict the computation and is intended for internal + * use only. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run on only edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. + */ + private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]) + : VertexRDD[A] + + /** + * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. + * The input table should contain at most one entry for each vertex. If no entry in `other` is * provided for a particular vertex in the graph, the map function receives `None`. 
* * @tparam U the type of entry in the table of updates diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index d0dd45dba618e..116d1ea700175 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali */ private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = { if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _) + graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None) } else if (edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _) + graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None) } else { // EdgeDirection.Either - graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _) + graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _, + TripletFields.None) } } @@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = { val nbrs = if (edgeDirection == EdgeDirection.Either) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _ - ) + graph.aggregateMessages[Array[VertexId]]( + ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) }, + _ ++ _, TripletFields.None) } else if (edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToSrc(Array(ctx.dstId)), + _ ++ _, TripletFields.None) } else if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToDst(Array(ctx.srcId)), + _ ++ _, TripletFields.None) } else { throw new SparkException("It doesn't make sense to collect neighbor ids without a " + "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)") @@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @return the vertex set of neighboring vertex attributes for each vertex */ def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { - val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]]( - edge => { - val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr))) - val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr))) - edgeDirection match { - case EdgeDirection.Either => Iterator(msgToSrc, msgToDst) - case EdgeDirection.In => Iterator(msgToDst) - case EdgeDirection.Out => Iterator(msgToSrc) - case EdgeDirection.Both => - throw new SparkException("collectNeighbors does not support EdgeDirection.Both. 
Use" + - "EdgeDirection.Either instead.") - } - }, - (a, b) => a ++ b) - - graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) => + val nbrs = edgeDirection match { + case EdgeDirection.Either => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => { + ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) + ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) + }, + (a, b) => a ++ b, TripletFields.All) + case EdgeDirection.In => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), + (a, b) => a ++ b, TripletFields.Src) + case EdgeDirection.Out => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), + (a, b) => a ++ b, TripletFields.Dst) + case EdgeDirection.Both => + throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" + + "EdgeDirection.Either instead.") + } + graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) => nbrsOpt.getOrElse(Array.empty[(VertexId, VD)]) } } // end of collectNeighbor @@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = { edgeDirection match { case EdgeDirection.Either => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))), - (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => { + ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + }, + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.In => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Out => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Both => throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" + "EdgeDirection.Either instead.") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java new file mode 100644 index 0000000000000..7eb4ae0f44602 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx; + +import java.io.Serializable; + +/** + * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the + * system to populate only those fields for efficiency. + */ +public class TripletFields implements Serializable { + + /** Indicates whether the source vertex attribute is included. */ + public final boolean useSrc; + + /** Indicates whether the destination vertex attribute is included. */ + public final boolean useDst; + + /** Indicates whether the edge attribute is included. */ + public final boolean useEdge; + + /** Constructs a default TripletFields in which all fields are included. */ + public TripletFields() { + this(true, true, true); + } + + public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) { + this.useSrc = useSrc; + this.useDst = useDst; + this.useEdge = useEdge; + } + + /** + * None of the triplet fields are exposed. + */ + public static final TripletFields None = new TripletFields(false, false, false); + + /** + * Expose only the edge field and not the source or destination field. + */ + public static final TripletFields EdgeOnly = new TripletFields(false, false, true); + + /** + * Expose the source and edge fields but not the destination field. (Same as Src) + */ + public static final TripletFields Src = new TripletFields(true, false, true); + + /** + * Expose the destination and edge fields but not the source field. (Same as Dst) + */ + public static final TripletFields Dst = new TripletFields(false, true, true); + + /** + * Expose all the fields (source, edge, and destination). + */ + public static final TripletFields All = new TripletFields(true, true, true); +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 2c8b245955d12..1db3df03c8052 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -27,8 +27,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock -import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._ -import org.apache.spark.graphx.impl.VertexRDDFunctions._ +import org.apache.spark.graphx.impl.VertexRDDImpl /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -55,62 +54,16 @@ import org.apache.spark.graphx.impl.VertexRDDFunctions._ * * @tparam VD the vertex attribute associated with each vertex in the set. */ -class VertexRDD[@specialized VD: ClassTag]( - val partitionsRDD: RDD[ShippableVertexPartition[VD]], - val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) - extends RDD[(VertexId, VD)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { +abstract class VertexRDD[VD]( + @transient sc: SparkContext, + @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { - require(partitionsRDD.partitioner.isDefined) + implicit protected def vdTag: ClassTag[VD] - /** - * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting - * VertexRDD will be based on a different index and can no longer be quickly joined with this - * RDD. 
- */ - def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex())) - - override val partitioner = partitionsRDD.partitioner + private[graphx] def partitionsRDD: RDD[ShippableVertexPartition[VD]] override protected def getPartitions: Array[Partition] = partitionsRDD.partitions - override protected def getPreferredLocations(s: Partition): Seq[String] = - partitionsRDD.preferredLocations(s) - - override def setName(_name: String): this.type = { - if (partitionsRDD.name != null) { - partitionsRDD.setName(partitionsRDD.name + ", " + _name) - } else { - partitionsRDD.setName(_name) - } - this - } - setName("VertexRDD") - - /** - * Persists the vertex partitions at the specified storage level, ignoring any existing target - * storage level. - */ - override def persist(newLevel: StorageLevel): this.type = { - partitionsRDD.persist(newLevel) - this - } - - override def unpersist(blocking: Boolean = true): this.type = { - partitionsRDD.unpersist(blocking) - this - } - - /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */ - override def cache(): this.type = { - partitionsRDD.persist(targetStorageLevel) - this - } - - /** The number of vertices in the RDD. */ - override def count(): Long = { - partitionsRDD.map(_.size.toLong).reduce(_ + _) - } - /** * Provides the `RDD[(VertexId, VD)]` equivalent output. */ @@ -118,22 +71,28 @@ class VertexRDD[@specialized VD: ClassTag]( firstParent[ShippableVertexPartition[VD]].iterator(part, context).next.iterator } + /** + * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting + * VertexRDD will be based on a different index and can no longer be quickly joined with this + * RDD. + */ + def reindex(): VertexRDD[VD] + /** * Applies a function to each `VertexPartition` of this RDD and returns a new VertexRDD. */ private[graphx] def mapVertexPartitions[VD2: ClassTag]( f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2]) - : VertexRDD[VD2] = { - val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true) - this.withPartitionsRDD(newPartitionsRDD) - } - + : VertexRDD[VD2] /** * Restricts the vertex set to the set of vertices satisfying the given predicate. This operation * preserves the index for efficient joins with the original RDD, and it sets bits in the bitmask * rather than allocating new memory. * + * It is declared and defined here to allow refining the return type from `RDD[(VertexId, VD)]` to + * `VertexRDD[VD]`. + * * @param pred the user defined predicate, which takes a tuple to conform to the * `RDD[(VertexId, VD)]` interface */ @@ -149,8 +108,7 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the * original VertexRDD */ - def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] = - this.mapVertexPartitions(_.map((vid, attr) => f(attr))) + def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] /** * Maps each vertex attribute, additionally supplying the vertex ID. @@ -161,23 +119,13 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the * original VertexRDD. The resulting VertexRDD retains the same index. 
*/ - def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = - this.mapVertexPartitions(_.map(f)) + def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] /** * Hides vertices that are the same between `this` and `other`; for vertices that are different, * keeps the values from `other`. */ - def diff(other: VertexRDD[VD]): VertexRDD[VD] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.diff(otherPart)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + def diff(other: VertexRDD[VD]): VertexRDD[VD] /** * Left joins this RDD with another VertexRDD with the same index. This function will fail if @@ -194,16 +142,7 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a VertexRDD containing the results of `f` */ def leftZipJoin[VD2: ClassTag, VD3: ClassTag] - (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.leftJoin(otherPart)(f)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] /** * Left joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is @@ -224,37 +163,14 @@ class VertexRDD[@specialized VD: ClassTag]( def leftJoin[VD2: ClassTag, VD3: ClassTag] (other: RDD[(VertexId, VD2)]) (f: (VertexId, VD, Option[VD2]) => VD3) - : VertexRDD[VD3] = { - // Test if the other vertex is a VertexRDD to choose the optimal join strategy. - // If the other set is a VertexRDD then we use the much more efficient leftZipJoin - other match { - case other: VertexRDD[_] => - leftZipJoin(other)(f) - case _ => - this.withPartitionsRDD[VD3]( - partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { - (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) - } - ) - } - } + : VertexRDD[VD3] /** * Efficiently inner joins this VertexRDD with another VertexRDD sharing the same index. See * [[innerJoin]] for the behavior of the join. */ def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U]) - (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.innerJoin(otherPart)(f)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] /** * Inner joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is @@ -268,21 +184,7 @@ class VertexRDD[@specialized VD: ClassTag]( * `this` and `other`, with values supplied by `f` */ def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) - (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { - // Test if the other vertex is a VertexRDD to choose the optimal join strategy. 
- // If the other set is a VertexRDD then we use the much more efficient innerZipJoin - other match { - case other: VertexRDD[_] => - innerZipJoin(other)(f) - case _ => - this.withPartitionsRDD( - partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { - (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) - } - ) - } - } + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] /** * Aggregates vertices in `messages` that have the same ids using `reduceFunc`, returning a @@ -296,38 +198,20 @@ class VertexRDD[@specialized VD: ClassTag]( * messages. */ def aggregateUsingIndex[VD2: ClassTag]( - messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { - val shuffled = messages.copartitionWithVertices(this.partitioner.get) - val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => - thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) - } - this.withPartitionsRDD[VD2](parts) - } + messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] /** * Returns a new `VertexRDD` reflecting a reversal of all edge directions in the corresponding * [[EdgeRDD]]. */ - def reverseRoutingTables(): VertexRDD[VD] = - this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) + def reverseRoutingTables(): VertexRDD[VD] /** Prepares this VertexRDD for efficient joins with the given EdgeRDD. */ - def withEdges(edges: EdgeRDD[_, _]): VertexRDD[VD] = { - val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get) - val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) { - (partIter, routingTableIter) => - val routingTable = - if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty - partIter.map(_.withRoutingTable(routingTable)) - } - this.withPartitionsRDD(vertexPartitions) - } + def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. */ private[graphx] def withPartitionsRDD[VD2: ClassTag]( - partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = { - new VertexRDD(partitionsRDD, this.targetStorageLevel) - } + partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] /** * Changes the target storage level while preserving all other properties of the @@ -337,20 +221,14 @@ class VertexRDD[@specialized VD: ClassTag]( * [[org.apache.spark.graphx.VertexRDD#cache]] on the returned VertexRDD. */ private[graphx] def withTargetStorageLevel( - targetStorageLevel: StorageLevel): VertexRDD[VD] = { - new VertexRDD(this.partitionsRDD, targetStorageLevel) - } + targetStorageLevel: StorageLevel): VertexRDD[VD] /** Generates an RDD of vertex attributes suitable for shipping to the edge partitions. */ private[graphx] def shipVertexAttributes( - shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = { - partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) - } + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] /** Generates an RDD of vertex IDs suitable for shipping to the edge partitions. 
*/ - private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { - partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) - } + private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] } // end of VertexRDD @@ -371,12 +249,12 @@ object VertexRDD { def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val vertexPartitions = vPartitioned.mapPartitions( iter => Iterator(ShippableVertexPartition(iter)), preservesPartitioning = true) - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } /** @@ -391,7 +269,7 @@ object VertexRDD { * @param defaultVal the vertex attribute to use when creating missing vertices */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = { + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD): VertexRDD[VD] = { VertexRDD(vertices, edges, defaultVal, (a, b) => a) } @@ -408,11 +286,11 @@ object VertexRDD { * @param mergeFunc the commutative, associative duplicate vertex attribute merge function */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD, mergeFunc: (VD, VD) => VD ): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get) val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) { @@ -421,7 +299,7 @@ object VertexRDD { if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal, mergeFunc)) } - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } /** @@ -436,25 +314,25 @@ object VertexRDD { * @param defaultVal the vertex attribute to use when creating missing vertices */ def fromEdges[VD: ClassTag]( - edges: EdgeRDD[_, _], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { + edges: EdgeRDD[_], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { val routingTables = createRoutingTables(edges, new HashPartitioner(numPartitions)) val vertexPartitions = routingTables.mapPartitions({ routingTableIter => val routingTable = if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty Iterator(ShippableVertexPartition(Iterator.empty, routingTable, defaultVal)) }, preservesPartitioning = true) - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } - private def createRoutingTables( - edges: EdgeRDD[_, _], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { + private[graphx] def createRoutingTables( + edges: EdgeRDD[_], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { // Determine which vertices each edge partition needs by creating a mapping from vid to pid. 
val vid2pid = edges.partitionsRDD.mapPartitions(_.flatMap( Function.tupled(RoutingTablePartition.edgePartitionToMsgs))) .setName("VertexRDD.createRoutingTables - vid2pid (aggregation)") val numEdgePartitions = edges.partitions.size - vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions( + vid2pid.partitionBy(vertexPartitioner).mapPartitions( iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)), preservesPartitioning = true) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java new file mode 100644 index 0000000000000..377ae849f045c --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl; + +/** + * Criteria for filtering edges based on activeness. For internal use only. + */ +public enum EdgeActiveness { + /** Neither the source vertex nor the destination vertex need be active. */ + Neither, + /** The source vertex must be active. */ + SrcOnly, + /** The destination vertex must be active. */ + DstOnly, + /** Both vertices must be active. */ + Both, + /** At least one vertex must be active. */ + Either +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index a5c9cd1f8b4e6..373af75448374 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -21,63 +21,94 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet /** - * A collection of edges stored in columnar format, along with any vertex attributes referenced. The - * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by - * src. There is an optional active vertex set for filtering computation on the edges. + * A collection of edges, along with referenced vertex attributes and an optional active vertex set + * for filtering computation on the edges. + * + * The edges are stored in columnar format in `localSrcIds`, `localDstIds`, and `data`. All + * referenced global vertex ids are mapped to a compact set of local vertex ids according to the + * `global2local` map. Each local vertex id is a valid index into `vertexAttrs`, which stores the + * corresponding vertex attribute, and `local2global`, which stores the reverse mapping to global + * vertex id. 
The global vertex ids that are active are optionally stored in `activeSet`. + * + * The edges are clustered by source vertex id, and the mapping from global vertex id to the index + * of the corresponding edge cluster is stored in `index`. * * @tparam ED the edge attribute type * @tparam VD the vertex attribute type * - * @param srcIds the source vertex id of each edge - * @param dstIds the destination vertex id of each edge + * @param localSrcIds the local source vertex id of each edge as an index into `local2global` and + * `vertexAttrs` + * @param localDstIds the local destination vertex id of each edge as an index into `local2global` + * and `vertexAttrs` * @param data the attribute associated with each edge - * @param index a clustered index on source vertex id - * @param vertices a map from referenced vertex ids to their corresponding attributes. Must - * contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for - * those vertex ids. The mask is not used. + * @param index a clustered index on source vertex id as a map from each global source vertex id to + * the offset in the edge arrays where the cluster for that vertex id begins + * @param global2local a map from referenced vertex ids to local ids which index into vertexAttrs + * @param local2global an array of global vertex ids where the offsets are local vertex ids + * @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids * @param activeSet an optional active vertex set for filtering computation on the edges */ private[graphx] class EdgePartition[ @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag]( - val srcIds: Array[VertexId] = null, - val dstIds: Array[VertexId] = null, - val data: Array[ED] = null, - val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, - val vertices: VertexPartition[VD] = null, - val activeSet: Option[VertexSet] = None - ) extends Serializable { + localSrcIds: Array[Int], + localDstIds: Array[Int], + data: Array[ED], + index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + activeSet: Option[VertexSet]) + extends Serializable { - /** Return a new `EdgePartition` with the specified edge data. */ - def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = { - new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet) - } + /** No-arg constructor for serialization. */ + private def this() = this(null, null, null, null, null, null, null, null) - /** Return a new `EdgePartition` with the specified vertex partition. */ - def withVertices[VD2: ClassTag]( - vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = { - new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet) + /** Return a new `EdgePartition` with the specified edge data. */ + def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = { + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) } /** Return a new `EdgePartition` with the specified active set, provided as an iterator. */ def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = { - val newActiveSet = new VertexSet - iter.foreach(newActiveSet.add(_)) - new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet)) - } - - /** Return a new `EdgePartition` with the specified active set. 
*/ - def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = { - new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_) + val activeSet = new VertexSet + while (iter.hasNext) { activeSet.add(iter.next()) } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, + Some(activeSet)) } /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */ def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = { - this.withVertices(vertices.innerJoinKeepLeft(iter)) + val newVertexAttrs = new Array[VD](vertexAttrs.length) + System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length) + while (iter.hasNext) { + val kv = iter.next() + newVertexAttrs(global2local(kv._1)) = kv._2 + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) } + /** Return a new `EdgePartition` without any locally cached vertex attributes. */ + def withoutVertexAttributes[VD2: ClassTag](): EdgePartition[ED, VD2] = { + val newVertexAttrs = new Array[VD2](vertexAttrs.length) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) + } + + @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos)) + + @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos)) + + @inline private def attrs(pos: Int): ED = data(pos) + /** Look up vid in activeSet, throwing an exception if it is None. */ def isActive(vid: VertexId): Boolean = { activeSet.get.contains(vid) @@ -92,11 +123,19 @@ class EdgePartition[ * @return a new edge partition with all edges reversed. */ def reverse: EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD]) - for (e <- iterator) { - builder.add(e.dstId, e.srcId, e.attr) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet, size) + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val srcId = local2global(localSrcId) + val dstId = local2global(localDstId) + val attr = data(i) + builder.add(dstId, srcId, localDstId, localSrcId, attr) + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -157,13 +196,25 @@ class EdgePartition[ def filter( epred: EdgeTriplet[VD, ED] => Boolean, vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = { - val filtered = tripletIterator().filter(et => - vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) - val builder = new EdgePartitionBuilder[ED, VD] - for (e <- filtered) { - builder.add(e.srcId, e.dstId, e.attr) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) + var i = 0 + while (i < size) { + // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge. 
+ val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val et = new EdgeTriplet[VD, ED] + et.srcId = local2global(localSrcId) + et.dstId = local2global(localDstId) + et.srcAttr = vertexAttrs(localSrcId) + et.dstAttr = vertexAttrs(localDstId) + et.attr = data(i) + if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) { + builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr) + } + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -183,28 +234,40 @@ class EdgePartition[ * @return a new edge partition without duplicate edges */ def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder[ED, VD] + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) var currSrcId: VertexId = null.asInstanceOf[VertexId] var currDstId: VertexId = null.asInstanceOf[VertexId] + var currLocalSrcId = -1 + var currLocalDstId = -1 var currAttr: ED = null.asInstanceOf[ED] + // Iterate through the edges, accumulating runs of identical edges using the curr* variables and + // releasing them to the builder when we see the beginning of the next run var i = 0 while (i < size) { if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) { + // This edge should be accumulated into the existing run currAttr = merge(currAttr, data(i)) } else { + // This edge starts a new run of edges if (i > 0) { - builder.add(currSrcId, currDstId, currAttr) + // First release the existing run to the builder + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } + // Then start accumulating for a new run currSrcId = srcIds(i) currDstId = dstIds(i) + currLocalSrcId = localSrcIds(i) + currLocalDstId = localDstIds(i) currAttr = data(i) } i += 1 } + // Finally, release the last accumulated run if (size > 0) { - builder.add(currSrcId, currDstId, currAttr) + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -220,7 +283,8 @@ class EdgePartition[ def innerJoin[ED2: ClassTag, ED3: ClassTag] (other: EdgePartition[ED2, _]) (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = { - val builder = new EdgePartitionBuilder[ED3, VD] + val builder = new ExistingEdgePartitionBuilder[ED3, VD]( + global2local, local2global, vertexAttrs, activeSet) var i = 0 var j = 0 // For i = index of each edge in `this`... @@ -233,12 +297,13 @@ class EdgePartition[ while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 } if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) { // ... run `f` on the matching edge - builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j))) + builder.add(srcId, dstId, localSrcIds(i), localDstIds(i), + f(srcId, dstId, this.data(i), other.attrs(j))) } } i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -246,7 +311,7 @@ class EdgePartition[ * * @return size of the partition */ - val size: Int = srcIds.size + val size: Int = localSrcIds.size /** The number of unique source vertices in the partition. */ def indexSize: Int = index.size @@ -280,55 +345,198 @@ class EdgePartition[ * It is safe to keep references to the objects from this iterator. 
*/ def tripletIterator( - includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = { - new EdgeTripletIterator(this, includeSrc, includeDst) + includeSrc: Boolean = true, includeDst: Boolean = true) + : Iterator[EdgeTriplet[VD, ED]] = new Iterator[EdgeTriplet[VD, ED]] { + private[this] var pos = 0 + + override def hasNext: Boolean = pos < EdgePartition.this.size + + override def next() = { + val triplet = new EdgeTriplet[VD, ED] + val localSrcId = localSrcIds(pos) + val localDstId = localDstIds(pos) + triplet.srcId = local2global(localSrcId) + triplet.dstId = local2global(localDstId) + if (includeSrc) { + triplet.srcAttr = vertexAttrs(localSrcId) + } + if (includeDst) { + triplet.dstAttr = vertexAttrs(localDstId) + } + triplet.attr = data(pos) + pos += 1 + triplet + } } /** - * Upgrade the given edge iterator into a triplet iterator. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning + * all edges sequentially. * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex + * @param tripletFields which triplet fields `sendMsg` uses + * @param activeness criteria for filtering edges based on activeness + * + * @return iterator aggregated messages keyed by the receiving vertex id */ - def upgradeIterator( - edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true) - : Iterator[EdgeTriplet[VD, ED]] = { - new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst) + def aggregateMessagesEdgeScan[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeness: EdgeActiveness): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) + + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val srcId = local2global(localSrcId) + val localDstId = localDstIds(i) + val dstId = local2global(localDstId) + val edgeIsActive = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) isActive(srcId) + else if (activeness == EdgeActiveness.DstOnly) isActive(dstId) + else if (activeness == EdgeActiveness.Both) isActive(srcId) && isActive(dstId) + else if (activeness == EdgeActiveness.Either) isActive(srcId) || isActive(dstId) + else throw new Exception("unreachable") + if (edgeIsActive) { + val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD] + val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i)) + sendMsg(ctx) + } + i += 1 + } + + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } } /** - * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The - * iterator is generated using an index scan, so it is efficient at skipping edges that don't - * match srcIdPred. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by + * filtering the source vertex index, then scanning each edge cluster. 
* - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. - */ - def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] = - index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator)) - - /** - * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The - * cluster must start at position `index`. + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex + * @param tripletFields which triplet fields `sendMsg` uses + * @param activeness criteria for filtering edges based on activeness * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @return iterator aggregated messages keyed by the receiving vertex id */ - private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] { - private[this] val edge = new Edge[ED] - private[this] var pos = index + def aggregateMessagesIndexScan[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeness: EdgeActiveness): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) + + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) + index.iterator.foreach { cluster => + val clusterSrcId = cluster._1 + val clusterPos = cluster._2 + val clusterLocalSrcId = localSrcIds(clusterPos) - override def hasNext: Boolean = { - pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId + val scanCluster = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) isActive(clusterSrcId) + else if (activeness == EdgeActiveness.DstOnly) true + else if (activeness == EdgeActiveness.Both) isActive(clusterSrcId) + else if (activeness == EdgeActiveness.Either) true + else throw new Exception("unreachable") + + if (scanCluster) { + var pos = clusterPos + val srcAttr = + if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD] + ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr) + while (pos < size && localSrcIds(pos) == clusterLocalSrcId) { + val localDstId = localDstIds(pos) + val dstId = local2global(localDstId) + val edgeIsActive = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) true + else if (activeness == EdgeActiveness.DstOnly) isActive(dstId) + else if (activeness == EdgeActiveness.Both) isActive(dstId) + else if (activeness == EdgeActiveness.Either) isActive(clusterSrcId) || isActive(dstId) + else throw new Exception("unreachable") + if (edgeIsActive) { + val dstAttr = + if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.setRest(dstId, localDstId, dstAttr, data(pos)) + sendMsg(ctx) + } + pos += 1 + } + } } - override def next(): Edge[ED] = { - assert(srcIds(pos) == srcId) - edge.srcId = srcIds(pos) - edge.dstId = dstIds(pos) - edge.attr = data(pos) - pos += 1 - edge + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } + } +} + +private class AggregatingEdgeContext[VD, ED, A]( + mergeMsg: (A, A) => A, + aggregates: Array[A], + bitset: BitSet) + extends EdgeContext[VD, ED, A] { + + private[this] var _srcId: VertexId = _ + private[this] var _dstId: VertexId 
= _ + private[this] var _localSrcId: Int = _ + private[this] var _localDstId: Int = _ + private[this] var _srcAttr: VD = _ + private[this] var _dstAttr: VD = _ + private[this] var _attr: ED = _ + + def set( + srcId: VertexId, dstId: VertexId, + localSrcId: Int, localDstId: Int, + srcAttr: VD, dstAttr: VD, + attr: ED) { + _srcId = srcId + _dstId = dstId + _localSrcId = localSrcId + _localDstId = localDstId + _srcAttr = srcAttr + _dstAttr = dstAttr + _attr = attr + } + + def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) { + _srcId = srcId + _localSrcId = localSrcId + _srcAttr = srcAttr + } + + def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) { + _dstId = dstId + _localDstId = localDstId + _dstAttr = dstAttr + _attr = attr + } + + override def srcId = _srcId + override def dstId = _dstId + override def srcAttr = _srcAttr + override def dstAttr = _dstAttr + override def attr = _attr + + override def sendToSrc(msg: A) { + send(_localSrcId, msg) + } + override def sendToDst(msg: A) { + send(_localDstId, msg) + } + + @inline private def send(localId: Int, msg: A) { + if (bitset.get(localId)) { + aggregates(localId) = mergeMsg(aggregates(localId), msg) + } else { + aggregates(localId) = msg + bitset.set(localId) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 2b6137be25547..b0cb0fe47d461 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -25,10 +25,11 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +/** Constructs an EdgePartition from scratch. */ private[graphx] class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( size: Int = 64) { - var edges = new PrimitiveVector[Edge[ED]](size) + private[this] val edges = new PrimitiveVector[Edge[ED]](size) /** Add a new edge to the partition. */ def add(src: VertexId, dst: VertexId, d: ED) { @@ -38,8 +39,67 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla def toEdgePartition: EdgePartition[ED, VD] = { val edgeArray = edges.trim().array Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering) - val srcIds = new Array[VertexId](edgeArray.size) - val dstIds = new Array[VertexId](edgeArray.size) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) + val data = new Array[ED](edgeArray.size) + val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val local2global = new PrimitiveVector[VertexId] + var vertexAttrs = Array.empty[VD] + // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and + // adding them to the index. Also populate a map from vertex id to a sequential local offset. 
+ if (edgeArray.length > 0) { + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId + var currLocalId = -1 + var i = 0 + while (i < edgeArray.size) { + val srcId = edgeArray(i).srcId + val dstId = edgeArray(i).dstId + localSrcIds(i) = global2local.changeValue(srcId, + { currLocalId += 1; local2global += srcId; currLocalId }, identity) + localDstIds(i) = global2local.changeValue(dstId, + { currLocalId += 1; local2global += dstId; currLocalId }, identity) + data(i) = edgeArray(i).attr + if (srcId != currSrcId) { + currSrcId = srcId + index.update(currSrcId, i) + } + + i += 1 + } + vertexAttrs = new Array[VD](currLocalId + 1) + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs, + None) + } +} + +/** + * Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables + * reuse of the local vertex ids. Intended for internal use in EdgePartition only. + */ +private[impl] +class ExistingEdgePartitionBuilder[ + @specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + activeSet: Option[VertexSet], + size: Int = 64) { + private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) + + /** Add a new edge to the partition. */ + def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) { + edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d) + } + + def toEdgePartition: EdgePartition[ED, VD] = { + val edgeArray = edges.trim().array + Sorting.quickSort(edgeArray)(EdgeWithLocalIds.lexicographicOrdering) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) val data = new Array[ED](edgeArray.size) val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and @@ -49,8 +109,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { - srcIds(i) = edgeArray(i).srcId - dstIds(i) = edgeArray(i).dstId + localSrcIds(i) = edgeArray(i).localSrcId + localDstIds(i) = edgeArray(i).localDstId data(i) = edgeArray(i).attr if (edgeArray(i).srcId != currSrcId) { currSrcId = edgeArray(i).srcId @@ -60,13 +120,24 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla } } - // Create and populate a VertexPartition with vids from the edges, but no attributes - val vidsIter = srcIds.iterator ++ dstIds.iterator - val vertexIds = new OpenHashSet[VertexId] - vidsIter.foreach(vid => vertexIds.add(vid)) - val vertices = new VertexPartition( - vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) + } +} - new EdgePartition(srcIds, dstIds, data, index, vertices) +private[impl] case class EdgeWithLocalIds[@specialized ED]( + srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED) + +private[impl] object EdgeWithLocalIds { + implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] { + override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = { + if (a.srcId == b.srcId) { + if (a.dstId == b.dstId) 0 + else if (a.dstId < b.dstId) -1 + else 1 + } else if (a.srcId < b.srcId) -1 + else 1 + } 
} + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala new file mode 100644 index 0000000000000..a8169613b4fd2 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.graphx._ + +class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( + override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], + val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) + extends EdgeRDD[ED](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { + + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("EdgeRDD") + + /** + * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the + * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new + * partitioner that allows co-partitioning with `partitionsRDD`. + */ + override val partitioner = + partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + + override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() + + /** + * Persists the edge partitions at the specified storage level, ignoring any existing target + * storage level. + */ + override def persist(newLevel: StorageLevel): this.type = { + partitionsRDD.persist(newLevel) + this + } + + override def unpersist(blocking: Boolean = true): this.type = { + partitionsRDD.unpersist(blocking) + this + } + + /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + override def cache(): this.type = { + partitionsRDD.persist(targetStorageLevel) + this + } + + /** The number of edges in the RDD. 
*/ + override def count(): Long = { + partitionsRDD.map(_._2.size.toLong).reduce(_ + _) + } + + override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDDImpl[ED2, VD] = + mapEdgePartitions((pid, part) => part.map(f)) + + override def reverse: EdgeRDDImpl[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) + + def filter( + epred: EdgeTriplet[VD, ED] => Boolean, + vpred: (VertexId, VD) => Boolean): EdgeRDDImpl[ED, VD] = { + mapEdgePartitions((pid, part) => part.filter(epred, vpred)) + } + + override def innerJoin[ED2: ClassTag, ED3: ClassTag] + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDDImpl[ED3, VD] = { + val ed2Tag = classTag[ED2] + val ed3Tag = classTag[ED3] + this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { + (thisIter, otherIter) => + val (pid, thisEPart) = thisIter.next() + val (_, otherEPart) = otherIter.next() + Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) + }) + } + + def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( + f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDDImpl[ED2, VD2] = { + this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => + if (iter.hasNext) { + val (pid, ep) = iter.next() + Iterator(Tuple2(pid, f(pid, ep))) + } else { + Iterator.empty + } + }, preservesPartitioning = true)) + } + + private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( + partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDDImpl[ED2, VD2] = { + new EdgeRDDImpl(partitionsRDD, this.targetStorageLevel) + } + + override private[graphx] def withTargetStorageLevel( + targetStorageLevel: StorageLevel): EdgeRDDImpl[ED, VD] = { + new EdgeRDDImpl(this.partitionsRDD, targetStorageLevel) + } + +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala deleted file mode 100644 index 56f79a7097fce..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.reflect.ClassTag - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -/** - * The Iterator type returned when constructing edge triplets. This could be an anonymous class in - * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile. 
- */ -private[impl] -class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - // Current position in the array. - private var pos = 0 - - override def hasNext: Boolean = pos < edgePartition.size - - override def next() = { - val triplet = new EdgeTriplet[VD, ED] - triplet.srcId = edgePartition.srcIds(pos) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) - } - triplet.dstId = edgePartition.dstIds(pos) - if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) - } - triplet.attr = edgePartition.data(pos) - pos += 1 - triplet - } -} - -/** - * An Iterator type for internal use that reuses EdgeTriplet objects. This could be an anonymous - * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug / - * profile. - */ -private[impl] -class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgeIter: Iterator[Edge[ED]], - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - private val triplet = new EdgeTriplet[VD, ED] - - override def hasNext = edgeIter.hasNext - - override def next() = { - triplet.set(edgeIter.next()) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) - } - if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) - } - triplet - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 33f35cfb69a26..0eae2a673874a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -23,7 +23,6 @@ import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils @@ -44,7 +43,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( /** Default constructor is provided to support serialization */ protected def this() = this(null, null) - @transient override val edges: EdgeRDD[ED, VD] = replicatedVertexView.edges + @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges /** Return a RDD that brings edges together with their source and destination vertices. 
*/ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { @@ -127,13 +126,12 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } override def mapTriplets[ED2: ClassTag]( - f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = { + f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] = { vertices.cache() - val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) => - part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr))) + part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst))) } new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges)) } @@ -171,15 +169,38 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { + + def sendMsg(ctx: EdgeContext[VD, ED, A]) { + mapFunc(ctx.toEdgeTriplet).foreach { kv => + val id = kv._1 + val msg = kv._2 + if (id == ctx.srcId) { + ctx.sendToSrc(msg) + } else { + assert(id == ctx.dstId) + ctx.sendToDst(msg) + } + } + } - vertices.cache() + val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") + val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") + val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true) + + aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt) + } + + override def aggregateMessagesWithActiveSet[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { + vertices.cache() // For each vertex, replicate its attribute only to partitions where it is // in the relevant position in an edge. 
- val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val view = activeSetOpt match { case Some((activeSet, _)) => replicatedVertexView.withActiveSet(activeSet) @@ -193,42 +214,40 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( case (pid, edgePartition) => // Choose scan method val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat - val edgeIter = activeDirectionOpt match { + activeDirectionOpt match { case Some(EdgeDirection.Both) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) - .filter(e => edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Both) } else { - edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Both) } case Some(EdgeDirection.Either) => // TODO: Because we only have a clustered index on the source vertex ID, we can't filter // the index here. Instead we have to scan all edges and then do the filter. - edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Either) case Some(EdgeDirection.Out) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.SrcOnly) } else { - edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.SrcOnly) } case Some(EdgeDirection.In) => - edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.DstOnly) case _ => // None - edgePartition.iterator + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Neither) } - - // Scan edges and run the map function - val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr) - .flatMap(mapFunc(_)) - // Note: This doesn't allow users to send messages to arbitrary vertices. 
- edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator - }).setName("GraphImpl.mapReduceTriplets - preAgg") + }).setName("GraphImpl.aggregateMessages - preAgg") // do the final reduction reusing the index map - vertices.aggregateUsingIndex(preAgg, reduceFunc) - } // end of mapReduceTriplets + vertices.aggregateUsingIndex(preAgg, mergeMsg) + } override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) @@ -304,11 +323,10 @@ object GraphImpl { */ def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = { + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { // Convert the vertex partitions in edges to the correct type - val newEdges = edges.mapEdgePartitions( - (pid, part) => part.withVertices(part.vertices.map( - (vid, attr) => null.asInstanceOf[VD]))) + val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] + .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) GraphImpl.fromExistingRDDs(vertices, newEdges) } @@ -319,8 +337,8 @@ object GraphImpl { */ def fromExistingRDDs[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, VD]): GraphImpl[VD, ED] = { - new GraphImpl(vertices, new ReplicatedVertexView(edges)) + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + new GraphImpl(vertices, new ReplicatedVertexView(edges.asInstanceOf[EdgeRDDImpl[ED, VD]])) } /** @@ -328,7 +346,7 @@ object GraphImpl { * `defaultVertexAttr`. The vertices will have the same number of partitions as the EdgeRDD. */ private def fromEdgeRDD[VD: ClassTag, ED: ClassTag]( - edges: EdgeRDD[ED, VD], + edges: EdgeRDDImpl[ED, VD], defaultVertexAttr: VD, edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala deleted file mode 100644 index 714f3b81c9dad..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.language.implicitConversions -import scala.reflect.{classTag, ClassTag} - -import org.apache.spark.Partitioner -import org.apache.spark.graphx.{PartitionID, VertexId} -import org.apache.spark.rdd.{ShuffledRDD, RDD} - - -private[graphx] -class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { - def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner) - - // Set a custom serializer if the data is of int or double type. 
- if (classTag[VD] == ClassTag.Int) { - rdd.setSerializer(new IntAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Long) { - rdd.setSerializer(new LongAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Double) { - rdd.setSerializer(new DoubleAggMsgSerializer) - } - rdd - } -} - -private[graphx] -object VertexRDDFunctions { - implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = { - new VertexRDDFunctions(rdd) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 86b366eb9202b..8ab255bd4038c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx._ */ private[impl] class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( - var edges: EdgeRDD[ED, VD], + var edges: EdgeRDDImpl[ED, VD], var hasSrcId: Boolean = false, var hasDstId: Boolean = false) { @@ -42,7 +42,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * shipping level. */ def withEdges[VD2: ClassTag, ED2: ClassTag]( - edges_ : EdgeRDD[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { new ReplicatedVertexView(edges_, hasSrcId, hasDstId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index b27485953f719..eb3c997e0f3c0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -29,24 +29,6 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage -private[graphx] -class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { - /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. 
*/ - def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, Int, Int]( - self, partitioner).setSerializer(new RoutingTableMessageSerializer) - } -} - -private[graphx] -object RoutingTableMessageRDDFunctions { - import scala.language.implicitConversions - - implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = { - new RoutingTableMessageRDDFunctions(rdd) - } -} - private[graphx] object RoutingTablePartition { /** @@ -74,11 +56,9 @@ object RoutingTablePartition { // Determine which positions each vertex id appears in using a map where the low 2 bits // represent src and dst val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte] - edgePartition.srcIds.iterator.foreach { srcId => - map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte) - } - edgePartition.dstIds.iterator.foreach { dstId => - map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) + edgePartition.iterator.foreach { e => + map.changeValue(e.srcId, 0x1, (b: Byte) => (b | 0x1).toByte) + map.changeValue(e.dstId, 0x2, (b: Byte) => (b | 0x2).toByte) } map.iterator.map { vidAndPosition => val vid = vidAndPosition._1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala deleted file mode 100644 index 3909efcdfc993..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ /dev/null @@ -1,369 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
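An aside on the byte flags in the routing-table hunk above: each vertex id is tagged with a 2-bit mask, bit 0 set when it occurs as a source and bit 1 when it occurs as a destination within the partition. A tiny self-contained sketch of the same bookkeeping in plain Scala (names are illustrative only, not from the patch):

// (srcId, dstId) pairs standing in for the edges of one partition.
val partEdges = Seq((1L, 2L), (2L, 3L), (3L, 1L), (1L, 3L))

// 0x1 = appears as src, 0x2 = appears as dst, 0x3 = both.
val positions = scala.collection.mutable.Map.empty[Long, Byte].withDefaultValue(0: Byte)
partEdges.foreach { case (src, dst) =>
  positions(src) = (positions(src) | 0x1).toByte
  positions(dst) = (positions(dst) | 0x2).toByte
}
// positions(1L) == 0x3: vertex 1 must be shipped to this partition as both src and dst.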
- */ - -package org.apache.spark.graphx.impl - -import scala.language.existentials - -import java.io.{EOFException, InputStream, OutputStream} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.serializer._ - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage - -private[graphx] -class RoutingTableMessageSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream): SerializationStream = - new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - val msg = t.asInstanceOf[RoutingTableMessage] - writeVarLong(msg._1, optimizePositive = false) - writeInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream): DeserializationStream = - new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readInt() - (a, b).asInstanceOf[T] - } - } - } -} - -private[graphx] -class VertexIdMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, _)] - writeVarLong(msg._1, optimizePositive = false) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - (readVarLong(optimizePositive = false), null).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Int]. */ -private[graphx] -class IntAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Int)] - writeVarLong(msg._1, optimizePositive = false) - writeUnsignedVarInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readUnsignedVarInt() - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Long]. */ -private[graphx] -class LongAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Long)] - writeVarLong(msg._1, optimizePositive = false) - writeVarLong(msg._2, optimizePositive = true) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readVarLong(optimizePositive = true) - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Double]. 
*/ -private[graphx] -class DoubleAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Double)] - writeVarLong(msg._1, optimizePositive = false) - writeDouble(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readDouble() - (a, b).asInstanceOf[T] - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Helper classes to shorten the implementation of those special serializers. -//////////////////////////////////////////////////////////////////////////////// - -private[graphx] -abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream { - // The implementation should override this one. - def writeObject[T: ClassTag](t: T): SerializationStream - - def writeInt(v: Int) { - s.write(v >> 24) - s.write(v >> 16) - s.write(v >> 8) - s.write(v) - } - - def writeUnsignedVarInt(value: Int) { - if ((value >>> 7) == 0) { - s.write(value.toInt) - } else if ((value >>> 14) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7) - } else if ((value >>> 21) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14) - } else if ((value >>> 28) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21) - } else { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21 | 0x80) - s.write(value >>> 28) - } - } - - def writeVarLong(value: Long, optimizePositive: Boolean) { - val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value - if ((v >>> 7) == 0) { - s.write(v.toInt) - } else if ((v >>> 14) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7).toInt) - } else if ((v >>> 21) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14).toInt) - } else if ((v >>> 28) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21).toInt) - } else if ((v >>> 35) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28).toInt) - } else if ((v >>> 42) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35).toInt) - } else if ((v >>> 49) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42).toInt) - } else if ((v >>> 56) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49).toInt) - } else { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 
| 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49 | 0x80).toInt) - s.write((v >>> 56).toInt) - } - } - - def writeLong(v: Long) { - s.write((v >>> 56).toInt) - s.write((v >>> 48).toInt) - s.write((v >>> 40).toInt) - s.write((v >>> 32).toInt) - s.write((v >>> 24).toInt) - s.write((v >>> 16).toInt) - s.write((v >>> 8).toInt) - s.write(v.toInt) - } - - def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v)) - - override def flush(): Unit = s.flush() - - override def close(): Unit = s.close() -} - -private[graphx] -abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream { - // The implementation should override this one. - def readObject[T: ClassTag](): T - - def readInt(): Int = { - val first = s.read() - if (first < 0) throw new EOFException - (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF) - } - - def readUnsignedVarInt(): Int = { - var value: Int = 0 - var i: Int = 0 - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b: Int = readOrThrow() - while ((b & 0x80) != 0) { - value |= (b & 0x7F) << i - i += 7 - if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long") - b = readOrThrow() - } - value | (b << i) - } - - def readVarLong(optimizePositive: Boolean): Long = { - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b = readOrThrow() - var ret: Long = b & 0x7F - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 7 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 14 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 21 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 28 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 35 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 42 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 49 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= b.toLong << 56 - } - } - } - } - } - } - } - } - if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret - } - - def readLong(): Long = { - val first = s.read() - if (first < 0) throw new EOFException() - (first.toLong << 56) | - (s.read() & 0xFF).toLong << 48 | - (s.read() & 0xFF).toLong << 40 | - (s.read() & 0xFF).toLong << 32 | - (s.read() & 0xFF).toLong << 24 | - (s.read() & 0xFF) << 16 | - (s.read() & 0xFF) << 8 | - (s.read() & 0xFF) - } - - def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong()) - - override def close(): Unit = s.close() -} - -private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance { - - override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException - - // The implementation should override the following two. 
- override def serializeStream(s: OutputStream): SerializationStream - override def deserializeStream(s: InputStream): DeserializationStream -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala new file mode 100644 index 0000000000000..d92a55a189298 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.graphx._ + +class VertexRDDImpl[VD] private[graphx] ( + val partitionsRDD: RDD[ShippableVertexPartition[VD]], + val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) + (implicit override protected val vdTag: ClassTag[VD]) + extends VertexRDD[VD](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { + + require(partitionsRDD.partitioner.isDefined) + + override def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex())) + + override val partitioner = partitionsRDD.partitioner + + override protected def getPreferredLocations(s: Partition): Seq[String] = + partitionsRDD.preferredLocations(s) + + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("VertexRDD") + + /** + * Persists the vertex partitions at the specified storage level, ignoring any existing target + * storage level. + */ + override def persist(newLevel: StorageLevel): this.type = { + partitionsRDD.persist(newLevel) + this + } + + override def unpersist(blocking: Boolean = true): this.type = { + partitionsRDD.unpersist(blocking) + this + } + + /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + override def cache(): this.type = { + partitionsRDD.persist(targetStorageLevel) + this + } + + /** The number of vertices in the RDD. 
*/ + override def count(): Long = { + partitionsRDD.map(_.size).reduce(_ + _) + } + + override private[graphx] def mapVertexPartitions[VD2: ClassTag]( + f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2]) + : VertexRDD[VD2] = { + val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true) + this.withPartitionsRDD(newPartitionsRDD) + } + + override def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] = + this.mapVertexPartitions(_.map((vid, attr) => f(attr))) + + override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = + this.mapVertexPartitions(_.map(f)) + + override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.diff(otherPart)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def leftZipJoin[VD2: ClassTag, VD3: ClassTag] + (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.leftJoin(otherPart)(f)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def leftJoin[VD2: ClassTag, VD3: ClassTag] + (other: RDD[(VertexId, VD2)]) + (f: (VertexId, VD, Option[VD2]) => VD3) + : VertexRDD[VD3] = { + // Test if the other vertex is a VertexRDD to choose the optimal join strategy. + // If the other set is a VertexRDD then we use the much more efficient leftZipJoin + other match { + case other: VertexRDD[_] => + leftZipJoin(other)(f) + case _ => + this.withPartitionsRDD[VD3]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) + } + ) + } + } + + override def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U]) + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.innerJoin(otherPart)(f)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { + // Test if the other vertex is a VertexRDD to choose the optimal join strategy. 
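The join-strategy comments above amount to: if the other side is already a VertexRDD (co-partitioned and indexed), the join degenerates into a cheap zip of partitions; otherwise the messages are first shuffled by the vertex partitioner. A hedged usage sketch (assumes an existing SparkContext `sc`):

import org.apache.spark.graphx._

val verts: VertexRDD[String] = VertexRDD(sc.parallelize(Seq((1L, "a"), (2L, "b"), (3L, "c"))))
val ages = sc.parallelize(Seq((1L, 30), (3L, 40)))            // plain pair RDD: shuffled path
val flags: VertexRDD[Boolean] = verts.mapValues(_ => true)    // VertexRDD: leftZipJoin path

val withAges  = verts.leftJoin(ages)  { (id, name, age)  => (name, age.getOrElse(-1)) }
val withFlags = verts.leftJoin(flags) { (id, name, flag) => (name, flag.getOrElse(false)) }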
+ // If the other set is a VertexRDD then we use the much more efficient innerZipJoin + other match { + case other: VertexRDD[_] => + innerZipJoin(other)(f) + case _ => + this.withPartitionsRDD( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) + } + ) + } + } + + override def aggregateUsingIndex[VD2: ClassTag]( + messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { + val shuffled = messages.partitionBy(this.partitioner.get) + val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => + thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) + } + this.withPartitionsRDD[VD2](parts) + } + + override def reverseRoutingTables(): VertexRDD[VD] = + this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) + + override def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] = { + val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get) + val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) { + (partIter, routingTableIter) => + val routingTable = + if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty + partIter.map(_.withRoutingTable(routingTable)) + } + this.withPartitionsRDD(vertexPartitions) + } + + override private[graphx] def withPartitionsRDD[VD2: ClassTag]( + partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = { + new VertexRDDImpl(partitionsRDD, this.targetStorageLevel) + } + + override private[graphx] def withTargetStorageLevel( + targetStorageLevel: StorageLevel): VertexRDD[VD] = { + new VertexRDDImpl(this.partitionsRDD, targetStorageLevel) + } + + override private[graphx] def shipVertexAttributes( + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) + } + + override private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) + } + +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 257e2f3a36115..e139959c3f5c1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -85,7 +85,7 @@ object PageRank extends Logging { // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree - .mapTriplets( e => 1.0 / e.srcAttr ) + .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.Src ) // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) @@ -96,8 +96,8 @@ object PageRank extends Logging { // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation. - val rankUpdates = rankGraph.mapReduceTriplets[Double]( - e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _) + val rankUpdates = rankGraph.aggregateMessages[Double]( + ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.Src) // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices // that didn't receive a message. 
Requires a shuffle for broadcasting updated ranks to the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index ccd7de537b6e3..f58587e10a820 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -74,9 +74,9 @@ object SVDPlusPlus { var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() // Calculate initial bias and norm - val t0 = g.mapReduceTriplets( - et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), - (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) + val t0 = g.aggregateMessages[(Long, Double)]( + ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, + (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) g = g.outerJoinVertices(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), @@ -84,15 +84,17 @@ object SVDPlusPlus { (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } - def mapTrainF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTrainF(conf: Conf, u: Double) + (ctx: EdgeContext[ + (DoubleMatrix, DoubleMatrix, Double, Double), + Double, + (DoubleMatrix, DoubleMatrix, Double)]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = et.attr - pred + val err = ctx.attr - pred val updateP = q.mul(err) .subColumnVector(p.mul(conf.gamma7)) .mul(conf.gamma2) @@ -102,16 +104,16 @@ object SVDPlusPlus { val updateY = q.mul(err * usr._4) .subColumnVector(itm._2.mul(conf.gamma7)) .mul(conf.gamma2) - Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)), - (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))) + ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) + ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } for (i <- 0 until conf.maxIters) { // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes g.cache() - val t1 = g.mapReduceTriplets( - et => Iterator((et.srcId, et.dstAttr._2)), - (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2)) + val t1 = g.aggregateMessages[DoubleMatrix]( + ctx => ctx.sendToSrc(ctx.dstAttr._2), + (g1, g2) => g1.addColumnVector(g2)) g = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => @@ -121,8 +123,8 @@ object SVDPlusPlus { // Phase 2, update p for user nodes and q, y for item nodes g.cache() - val t2 = g.mapReduceTriplets( - mapTrainF(conf, u), + val t2 = g.aggregateMessages( + sendMsgTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) g = g.outerJoinVertices(t2) { @@ -135,20 +137,18 @@ object SVDPlusPlus { } // calculate error on training set - def mapTestF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, Double)] = - { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTestF(conf: Conf, u: Double) + (ctx: EdgeContext[(DoubleMatrix, 
DoubleMatrix, Double, Double), Double, Double]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = (et.attr - pred) * (et.attr - pred) - Iterator((et.dstId, err)) + val err = (ctx.attr - pred) * (ctx.attr - pred) + ctx.sendToDst(err) } g.cache() - val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) + val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) g = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index 7c396e6e66a28..daf162085e3e4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -61,26 +61,27 @@ object TriangleCount { (vid, _, optSet) => optSet.getOrElse(null) } // Edge function computes intersection of smaller vertex with larger vertex - def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexId, Int)] = { - assert(et.srcAttr != null) - assert(et.dstAttr != null) - val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) { - (et.srcAttr, et.dstAttr) + def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) { + assert(ctx.srcAttr != null) + assert(ctx.dstAttr != null) + val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) { + (ctx.srcAttr, ctx.dstAttr) } else { - (et.dstAttr, et.srcAttr) + (ctx.dstAttr, ctx.srcAttr) } val iter = smallSet.iterator var counter: Int = 0 while (iter.hasNext) { val vid = iter.next() - if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) { + if (vid != ctx.srcId && vid != ctx.dstId && largeSet.contains(vid)) { counter += 1 } } - Iterator((et.srcId, counter), (et.dstId, counter)) + ctx.sendToSrc(counter) + ctx.sendToDst(counter) } // compute the intersection along edges - val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _) + val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _) // Merge counters with the graph and divide by two since each triangle is counted twice g.outerJoinVertices(counters) { (vid, _, optCounter: Option[Int]) => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 6506bac73d71c..a05d1ddb21295 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -118,7 +118,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Each vertex should be replicated to at most 2 * sqrt(p) partitions val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) { val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound) @@ -130,7 +130,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // This should not be true for the default hash partitioning val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { 
iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound)) @@ -318,6 +318,21 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("aggregateMessages") { + withSpark { sc => + val n = 5 + val agg = starGraph(sc, n).aggregateMessages[String]( + ctx => { + if (ctx.dstAttr != null) { + throw new Exception( + "expected ctx.dstAttr to be null due to TripletFields, but it was " + ctx.dstAttr) + } + ctx.sendToDst(ctx.srcAttr) + }, _ + _, TripletFields.Src) + assert(agg.collect().toSet === (1 to n).map(x => (x: VertexId, "v")).toSet) + } + } + test("outerJoinVertices") { withSpark { sc => val n = 5 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala deleted file mode 100644 index 864cb1fdf0022..0000000000000 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.graphx - -import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream} - -import scala.util.Random -import scala.reflect.ClassTag - -import org.scalatest.FunSuite - -import org.apache.spark._ -import org.apache.spark.graphx.impl._ -import org.apache.spark.serializer.SerializationStream - - -class SerializerSuite extends FunSuite with LocalSparkContext { - - test("IntAggMsgSerializer") { - val outMsg = (4: VertexId, 5) - val bout = new ByteArrayOutputStream - val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Int) = inStrm.readObject() - val inMsg2: (VertexId, Int) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("LongAggMsgSerializer") { - val outMsg = (4: VertexId, 1L << 32) - val bout = new ByteArrayOutputStream - val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Long) = inStrm.readObject() - val inMsg2: (VertexId, Long) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("DoubleAggMsgSerializer") { - val outMsg = (4: VertexId, 5.0) - val bout = new ByteArrayOutputStream - val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Double) = inStrm.readObject() - val inMsg2: (VertexId, Double) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("variable long encoding") { - def testVarLongEncoding(v: Long, optimizePositive: Boolean) { - val bout = new ByteArrayOutputStream - val stream = new ShuffleSerializationStream(bout) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive) - this - } - } - stream.writeObject(v) - - val bin = new ByteArrayInputStream(bout.toByteArray) - val dstream = new ShuffleDeserializationStream(bin) { - def readObject[T: ClassTag](): T = { - readVarLong(optimizePositive).asInstanceOf[T] - } - } - val read = dstream.readObject[Long]() - assert(read === v) - } - - // Test all variable encoding code path (each branch uses 7 bits, i.e. 
1L << 7 difference) - val d = Random.nextLong() % 128 - Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d, - 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number => - testVarLongEncoding(number, optimizePositive = false) - testVarLongEncoding(number, optimizePositive = true) - testVarLongEncoding(-number, optimizePositive = false) - testVarLongEncoding(-number, optimizePositive = true) - } - } -} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index db1dac6160080..515f3a9cd02eb 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -82,29 +82,6 @@ class EdgePartitionSuite extends FunSuite { assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges) } - test("upgradeIterator") { - val edges = List((0, 1, 0), (1, 0, 0)) - val verts = List((0L, 1), (1L, 2)) - val part = makeEdgePartition(edges).updateVertices(verts.iterator) - assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList === - part.tripletIterator().toList.map(_.toTuple)) - } - - test("indexIterator") { - val edgesFrom0 = List(Edge(0, 1, 0)) - val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0)) - val sortedEdges = edgesFrom0 ++ edgesFrom1 - val builder = new EdgePartitionBuilder[Int, Nothing] - for (e <- Random.shuffle(sortedEdges)) { - builder.add(e.srcId, e.dstId, e.attr) - } - - val edgePartition = builder.toEdgePartition - assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges) - assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0) - assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1) - } - test("innerJoin") { val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0)) @@ -125,8 +102,18 @@ class EdgePartitionSuite extends FunSuite { assert(ep.numActives == Some(2)) } + test("tripletIterator") { + val builder = new EdgePartitionBuilder[Int, Int] + builder.add(1, 2, 0) + builder.add(1, 3, 0) + builder.add(1, 4, 0) + val ep = builder.toEdgePartition + val result = ep.tripletIterator().toList.map(et => (et.srcId, et.dstId)) + assert(result === Seq((1, 2), (1, 3), (1, 4))) + } + test("serialization") { - val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) + val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5)) val a: EdgePartition[Int, Int] = makeEdgePartition(aList) val javaSer = new JavaSerializer(new SparkConf()) val conf = new SparkConf() @@ -135,11 +122,7 @@ class EdgePartitionSuite extends FunSuite { for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) - assert(aSer.srcIds.toList === a.srcIds.toList) - assert(aSer.dstIds.toList === a.dstIds.toList) - assert(aSer.data.toList === a.data.toList) - assert(aSer.index != null) - assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet) + assert(aSer.tripletIterator().toList === a.tripletIterator().toList) } } } diff --git a/make-distribution.sh b/make-distribution.sh index 0bc839e1dbe4d..45c99e42e5a5b 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -59,7 +59,7 @@ while (( "$#" )); do exit_with_usage ;; --with-hive) - echo "Error: '--with-hive' is no longer supported, 
use Maven option -Phive" + echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; --skip-java-test) @@ -119,7 +119,7 @@ VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v " SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles $@ 2>/dev/null\ +SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ @@ -181,6 +181,9 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +# This will fail if the -Pyarn profile is not provided +# In this case, silence the error and ignore the return code of this command +cp "$FWDIR"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" @@ -198,6 +201,9 @@ if [ -e "$FWDIR"/CHANGES.txt ]; then cp "$FWDIR/CHANGES.txt" "$DISTDIR" fi +# Copy data files +cp -r "$FWDIR/data" "$DISTDIR" + # Copy other things mkdir "$DISTDIR"/conf cp "$FWDIR"/conf/*.template "$DISTDIR"/conf diff --git a/mllib/pom.xml b/mllib/pom.xml index 87a7ddaba97f2..0a6dda0ab8c80 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -100,6 +100,11 @@ junit-interface test + + org.mockito + mockito-all + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala new file mode 100644 index 0000000000000..fdbee743e8177 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for estimators that fit models to data. + */ +@AlphaComponent +abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { + + /** + * Fits a single model to the input data with optional parameters. 
+ * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + val map = new ParamMap().put(paramPairs: _*) + fit(dataset, map) + } + + /** + * Fits a single model to the input data with provided parameter map. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: SchemaRDD, paramMap: ParamMap): M + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * The default implementation uses a for loop on each parameter map. + * Subclasses could overwrite this to optimize multi-model training. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + paramMaps.map(fit(dataset, _)) + } + + // Java-friendly versions of fit. + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = { + fit(dataset.schemaRDD, paramPairs: _*) + } + + /** + * Fits a single model to the input data with provided parameter map. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = { + fit(dataset.schemaRDD, paramMap) + } + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = { + fit(dataset.schemaRDD, paramMaps).asJava + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala new file mode 100644 index 0000000000000..db563dd550e56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for evaluators that compute metrics from predictions. + */ +@AlphaComponent +abstract class Evaluator extends Identifiable { + + /** + * Evaluates the output. 
+ * + * @param dataset a dataset that contains labels/observations and predictions. + * @param paramMap parameter map that specifies the input columns and output metrics + * @return metric + */ + def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala new file mode 100644 index 0000000000000..cd84b05bfb496 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import java.util.UUID + +/** + * Object with a unique id. + */ +private[ml] trait Identifiable extends Serializable { + + /** + * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * random hex chars. + */ + private[ml] val uid: String = + this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala new file mode 100644 index 0000000000000..cae5082b51196 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap + +/** + * :: AlphaComponent :: + * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. + * + * @tparam M model type + */ +@AlphaComponent +abstract class Model[M <: Model[M]] extends Transformer { + /** + * The parent estimator that produced this model. + */ + val parent: Estimator[M] + + /** + * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. 
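The parent/fittingParamMap contract just described means a fitted model knows how it was produced. A small, hedged sketch (`est` stands for any concrete Estimator, e.g. the LogisticRegression added later in this patch, and `data` for a SchemaRDD with the expected columns; neither name comes from the patch):

val model = est.fit(data)                     // varargs overload, no extra param pairs
assert(model.parent eq est)                   // the model remembers which estimator fit it
// Re-running the parent with the recorded fitting params should reproduce the model.
val retrained = model.parent.fit(data, model.fittingParamMap)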
+ */ + val fittingParamMap: ParamMap +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala new file mode 100644 index 0000000000000..e545df1e37b9c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.collection.mutable.ListBuffer + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, Param, ParamMap} +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * :: AlphaComponent :: + * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. + */ +@AlphaComponent +abstract class PipelineStage extends Serializable with Logging { + + /** + * Derives the output schema from the input schema and parameters. + */ + private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + + /** + * Derives the output schema from the input schema and parameters, optionally with logging. + */ + protected def transformSchema( + schema: StructType, + paramMap: ParamMap, + logging: Boolean): StructType = { + if (logging) { + logDebug(s"Input schema: ${schema.json}") + } + val outputSchema = transformSchema(schema, paramMap) + if (logging) { + logDebug(s"Expected output schema: ${outputSchema.json}") + } + outputSchema + } +} + +/** + * :: AlphaComponent :: + * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each + * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the + * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will + * be called on the input dataset to fit a model. Then the model, which is a transformer, will be + * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], + * its [[Transformer.transform]] method will be called to produce the dataset for the next stage. + * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and + * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as + * an identity transformer. + */ +@AlphaComponent +class Pipeline extends Estimator[PipelineModel] { + + /** param for pipeline stages */ + val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") + def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def getStages: Array[PipelineStage] = get(stages) + + /** + * Fits the pipeline to the input dataset with additional parameters. 
If a stage is an + * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model. + * Then the model, which is a transformer, will be used to transform the dataset as the input to + * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be + * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an + * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the + * pipeline stages. If there are no stages, the output model acts as an identity transformer. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted pipeline + */ + override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val theStages = map(stages) + // Search for the last estimator. + var indexOfLastEstimator = -1 + theStages.view.zipWithIndex.foreach { case (stage, index) => + stage match { + case _: Estimator[_] => + indexOfLastEstimator = index + case _ => + } + } + var curDataset = dataset + val transformers = ListBuffer.empty[Transformer] + theStages.view.zipWithIndex.foreach { case (stage, index) => + if (index <= indexOfLastEstimator) { + val transformer = stage match { + case estimator: Estimator[_] => + estimator.fit(curDataset, paramMap) + case t: Transformer => + t + case _ => + throw new IllegalArgumentException( + s"Do not support stage $stage of type ${stage.getClass}") + } + curDataset = transformer.transform(curDataset, paramMap) + transformers += transformer + } else { + transformers += stage.asInstanceOf[Transformer] + } + } + + new PipelineModel(this, map, transformers.toArray) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val theStages = map(stages) + require(theStages.toSet.size == theStages.size, + "Cannot have duplicate components in a pipeline.") + theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap)) + } +} + +/** + * :: AlphaComponent :: + * Represents a compiled pipeline. + */ +@AlphaComponent +class PipelineModel private[ml] ( + override val parent: Pipeline, + override val fittingParamMap: ParamMap, + private[ml] val stages: Array[Transformer]) + extends Model[PipelineModel] with Logging { + + /** + * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input + * estimator does not exist in the pipeline. 
+ */ + def getModel[M <: Model[M]](stage: Estimator[M]): M = { + val matched = stages.filter { + case m: Model[_] => m.parent.eq(stage) + case _ => false + } + if (matched.isEmpty) { + throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.") + } else if (matched.size > 1) { + throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.") + } else { + matched.head.asInstanceOf[M] + } + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala new file mode 100644 index 0000000000000..490e6609ad311 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.types._ + +/** + * :: AlphaComponent :: + * Abstract class for transformers that transform one dataset into another. + */ +@AlphaComponent +abstract class Transformer extends PipelineStage with Params { + + /** + * Transforms the dataset with optional parameters + * @param dataset input dataset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + val map = new ParamMap() + paramPairs.foreach(map.put(_)) + transform(dataset, map) + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + + // Java-friendly versions of transform. + + /** + * Transforms the dataset with optional parameters. 
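Continuing that sketch, the fitted PipelineModel can return the model produced for a particular estimator, and params supplied at call time overwrite the values embedded during fitting:

val lrModel = model.getModel(lr)                          // the LogisticRegressionModel fit for lr
val rescored = model.transform(test, lrModel.threshold -> 0.7)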
+ * @param dataset input datset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD + } +} + +/** + * Abstract class for transformers that take one input column, apply transformation, and output the + * result as a new column. + */ +private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]] + extends Transformer with HasInputCol with HasOutputCol with Logging { + + def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] + + /** + * Creates the transform function using the given param map. The input param map already takes + * account of the embedded param map. So the param values should be determined solely by the input + * param map. + */ + protected def createTransformFunc(paramMap: ParamMap): IN => OUT + + /** + * Validates the input type. Throw an exception if it is invalid. + */ + protected def validateInputType(inputType: DataType): Unit = {} + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + validateInputType(inputType) + if (schema.fieldNames.contains(map(outputCol))) { + throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") + } + val output = ScalaReflection.schemaFor[OUT] + val outputFields = schema.fields :+ + StructField(map(outputCol), output.dataType, output.nullable) + StructType(outputFields) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val udf = this.createTransformFunc(map) + dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala new file mode 100644 index 0000000000000..85b8899636ca5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.storage.StorageLevel + +/** + * :: AlphaComponent :: + * Params for logistic regression. + */ +@AlphaComponent +private[classification] trait LogisticRegressionParams extends Params + with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol + with HasScoreCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is in fitting + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean): StructType = { + val map = this.paramMap ++ paramMap + val featuresType = schema(map(featuresCol)).dataType + // TODO: Support casting Array[Double] and Array[Float] to Vector. + require(featuresType.isInstanceOf[VectorUDT], + s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") + if (fitting) { + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") + } + val fieldNames = schema.fieldNames + require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") + require(!fieldNames.contains(map(predictionCol)), + s"Prediction column ${map(predictionCol)} already exists.") + val outputFields = schema.fields ++ Seq( + StructField(map(scoreCol), DoubleType, false), + StructField(map(predictionCol), DoubleType, false)) + StructType(outputFields) + } +} + +/** + * Logistic regression. 
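The estimator defined below can also be used on its own. A brief sketch, assuming training is a SchemaRDD with "label" and "features" columns:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression()
  .setMaxIter(50)
  .setRegParam(0.1)
  .setThreshold(0.6)
val lrModel = lr.fit(training, ParamMap.empty)
val scored = lrModel.transform(training)   // appends "score" and "prediction" columns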
+ */ +class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { + + setRegParam(0.1) + setMaxIter(100) + setThreshold(0.5) + + def setRegParam(value: Double): this.type = set(regParam, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + }.persist(StorageLevel.MEMORY_AND_DISK) + val lr = new LogisticRegressionWithLBFGS + lr.optimizer + .setRegParam(map(regParam)) + .setNumIterations(map(maxIter)) + val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) + instances.unpersist() + // copy model params + Params.inheritValues(map, this, lrm) + lrm + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true) + } +} + +/** + * :: AlphaComponent :: + * Model produced by [[LogisticRegression]]. + */ +@AlphaComponent +class LogisticRegressionModel private[ml] ( + override val parent: LogisticRegression, + override val fittingParamMap: ParamMap, + weights: Vector) + extends Model[LogisticRegressionModel] with LogisticRegressionParams { + + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val score: Vector => Double = (v) => { + val margin = BLAS.dot(v, weights) + 1.0 / (1.0 + math.exp(-margin)) + } + val t = map(threshold) + val predict: Double => Double = (score) => { + if (score > t) 1.0 else 0.0 + } + dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) + .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala new file mode 100644 index 0000000000000..0b0504e036ec9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.sql.{DoubleType, Row, SchemaRDD} + +/** + * :: AlphaComponent :: + * Evaluator for binary classification, which expects two input columns: score and label. + */ +@AlphaComponent +class BinaryClassificationEvaluator extends Evaluator with Params + with HasScoreCol with HasLabelCol { + + /** param for metric name in evaluation */ + val metricName: Param[String] = new Param(this, "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + def getMetricName: String = get(metricName) + def setMetricName(value: String): this.type = set(metricName, value) + + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + val map = this.paramMap ++ paramMap + + val schema = dataset.schema + val scoreType = schema(map(scoreCol)).dataType + require(scoreType == DoubleType, + s"Score column ${map(scoreCol)} must be double type but found $scoreType") + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Label column ${map(labelCol)} must be double type but found $labelType") + + import dataset.sqlContext._ + val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + .map { case Row(score: Double, label: Double) => + (score, label) + } + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val metric = map(metricName) match { + case "areaUnderROC" => + metrics.areaUnderROC() + case "areaUnderPR" => + metrics.areaUnderPR() + case other => + throw new IllegalArgumentException(s"Does not support metric $other.") + } + metrics.unpersist() + metric + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala new file mode 100644 index 0000000000000..b98b1755a3584 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.Vector + +/** + * :: AlphaComponent :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + */ +@AlphaComponent +class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { + + /** number of features */ + val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + def setNumFeatures(value: Int) = set(numFeatures, value) + def getNumFeatures: Int = get(numFeatures) + + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { + val hashingTF = new feature.HashingTF(paramMap(numFeatures)) + hashingTF.transform + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala new file mode 100644 index 0000000000000..896a6b83b67bf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ + +/** + * Params for [[StandardScaler]] and [[StandardScalerModel]]. + */ +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol + +/** + * :: AlphaComponent :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. 
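A short sketch of the scaler defined next, assuming dataset is a SchemaRDD with a vector-valued "features" column:

import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.param.ParamMap

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
val scalerModel = scaler.fit(dataset, ParamMap.empty)
val scaled = scalerModel.transform(dataset)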
+ */ +@AlphaComponent +class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val input = dataset.select(map(inputCol).attr) + .map { case Row(v: Vector) => + v + } + val scaler = new feature.StandardScaler().fit(input) + val model = new StandardScalerModel(this, map, scaler) + Params.inheritValues(map, this, model) + model + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StandardScaler]]. + */ +@AlphaComponent +class StandardScalerModel private[ml] ( + override val parent: StandardScaler, + override val fittingParamMap: ParamMap, + scaler: feature.StandardScalerModel) + extends Model[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val scale: (Vector) => Vector = (v) => { + scaler.transform(v) + } + dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala new file mode 100644 index 0000000000000..0a6599b64c011 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.{DataType, StringType} + +/** + * :: AlphaComponent :: + * A tokenizer that converts the input string to lowercase and then splits it by white spaces. + */ +@AlphaComponent +class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { + + protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + _.toLowerCase.split("\\s") + } + + protected override def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java new file mode 100644 index 0000000000000..00d9c802e930d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +@AlphaComponent +package org.apache.spark.ml; + +import org.apache.spark.annotation.AlphaComponent; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala new file mode 100644 index 0000000000000..51cd48c90432a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala new file mode 100644 index 0000000000000..8fd46aef4b99d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +import java.lang.reflect.Modifier + +import org.apache.spark.annotation.AlphaComponent + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.ml.Identifiable + +/** + * :: AlphaComponent :: + * A param with self-contained documentation and optionally default value. Primitive-typed param + * should use the specialized versions, which are more friendly to Java users. + * + * @param parent parent object + * @param name param name + * @param doc documentation + * @tparam T param value type + */ +@AlphaComponent +class Param[T] ( + val parent: Params, + val name: String, + val doc: String, + val defaultValue: Option[T] = None) + extends Serializable { + + /** + * Creates a param pair with the given value (for Java). + */ + def w(value: T): ParamPair[T] = this -> value + + /** + * Creates a param pair with the given value (for Scala). + */ + def ->(value: T): ParamPair[T] = ParamPair(this, value) + + override def toString: String = { + if (defaultValue.isDefined) { + s"$name: $doc (default: ${defaultValue.get})" + } else { + s"$name: $doc" + } + } +} + +// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... + +/** Specialized version of [[Param[Double]]] for Java. */ +class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) + extends Param[Double](parent, name, doc, defaultValue) { + + override def w(value: Double): ParamPair[Double] = super.w(value) +} + +/** Specialized version of [[Param[Int]]] for Java. */ +class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) + extends Param[Int](parent, name, doc, defaultValue) { + + override def w(value: Int): ParamPair[Int] = super.w(value) +} + +/** Specialized version of [[Param[Float]]] for Java. 
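To illustrate how components are expected to declare params, here is a hypothetical unary transformer with its own DoubleParam. Because set and get are private[ml] in this alpha API, the sketch assumes the class lives under the org.apache.spark.ml package:

package org.apache.spark.ml.feature

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}

// Hypothetical: multiplies every element of a Seq[Double] column by a configurable factor.
class Rescaler extends UnaryTransformer[Seq[Double], Seq[Double], Rescaler] {

  val factor: DoubleParam = new DoubleParam(this, "factor", "multiplicative factor", Some(1.0))
  def setFactor(value: Double): this.type = set(factor, value)
  def getFactor: Double = get(factor)

  override protected def createTransformFunc(paramMap: ParamMap): Seq[Double] => Seq[Double] = {
    val f = paramMap(factor)   // falls back to the default (1.0) if not set
    xs => xs.map(_ * f)
  }
}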
*/ +class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) + extends Param[Float](parent, name, doc, defaultValue) { + + override def w(value: Float): ParamPair[Float] = super.w(value) +} + +/** Specialized version of [[Param[Long]]] for Java. */ +class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) + extends Param[Long](parent, name, doc, defaultValue) { + + override def w(value: Long): ParamPair[Long] = super.w(value) +} + +/** Specialized version of [[Param[Boolean]]] for Java. */ +class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) + extends Param[Boolean](parent, name, doc, defaultValue) { + + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) +} + +/** + * A param amd its value. + */ +case class ParamPair[T](param: Param[T], value: T) + +/** + * :: AlphaComponent :: + * Trait for components that take parameters. This also provides an internal param map to store + * parameter values attached to the instance. + */ +@AlphaComponent +trait Params extends Identifiable with Serializable { + + /** Returns all params. */ + def params: Array[Param[_]] = { + val methods = this.getClass.getMethods + methods.filter { m => + Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty + }.sortBy(_.getName) + .map(m => m.invoke(this).asInstanceOf[Param[_]]) + } + + /** + * Validates parameter values stored internally plus the input parameter map. + * Raises an exception if any parameter is invalid. + */ + def validate(paramMap: ParamMap): Unit = {} + + /** + * Validates parameter values stored internally. + * Raise an exception if any parameter value is invalid. + */ + def validate(): Unit = validate(ParamMap.empty) + + /** + * Returns the documentation of all params. + */ + def explainParams(): String = params.mkString("\n") + + /** Checks whether a param is explicitly set. */ + def isSet(param: Param[_]): Boolean = { + require(param.parent.eq(this)) + paramMap.contains(param) + } + + /** Gets a param by its name. */ + private[ml] def getParam(paramName: String): Param[Any] = { + val m = this.getClass.getMethod(paramName) + assert(Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty) + m.invoke(this).asInstanceOf[Param[Any]] + } + + /** + * Sets a parameter in the embedded param map. + */ + private[ml] def set[T](param: Param[T], value: T): this.type = { + require(param.parent.eq(this)) + paramMap.put(param.asInstanceOf[Param[Any]], value) + this + } + + /** + * Gets the value of a parameter in the embedded param map. + */ + private[ml] def get[T](param: Param[T]): T = { + require(param.parent.eq(this)) + paramMap(param) + } + + /** + * Internal param map. + */ + protected val paramMap: ParamMap = ParamMap.empty +} + +private[ml] object Params { + + /** + * Copies parameter values from the parent estimator to the child model it produced. + * @param paramMap the param map that holds parameters of the parent + * @param parent the parent estimator + * @param child the child model + */ + def inheritValues[E <: Params, M <: E]( + paramMap: ParamMap, + parent: E, + child: M): Unit = { + parent.params.foreach { param => + if (paramMap.contains(param)) { + child.set(child.getParam(param.name), paramMap(param)) + } + } + } +} + +/** + * :: AlphaComponent :: + * A param to value map. 
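A quick sketch of combining param maps with the class defined next, reusing the lr estimator from the earlier sketches:

val overrides = ParamMap(lr.maxIter -> 30, lr.regParam -> 0.05)
val merged = overrides ++ ParamMap(lr.regParam -> 0.1)   // the right-hand map wins on conflicts
merged(lr.maxIter)      // 30
merged(lr.regParam)     // 0.1
merged(lr.featuresCol)  // "features", falling back to the param's declared default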
+ */ +@AlphaComponent +class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { + + /** + * Creates an empty param map. + */ + def this() = this(mutable.Map.empty[Param[Any], Any]) + + /** + * Puts a (param, value) pair (overwrites if the input param exists). + */ + def put[T](param: Param[T], value: T): this.type = { + map(param.asInstanceOf[Param[Any]]) = value + this + } + + /** + * Puts a list of param pairs (overwrites if the input params exists). + */ + def put(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + put(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + /** + * Optionally returns the value associated with a param or its default. + */ + def get[T](param: Param[T]): Option[T] = { + map.get(param.asInstanceOf[Param[Any]]) + .orElse(param.defaultValue) + .asInstanceOf[Option[T]] + } + + /** + * Gets the value of the input param or its default value if it does not exist. + * Raises a NoSuchElementException if there is no value associated with the input param. + */ + def apply[T](param: Param[T]): T = { + val value = get(param) + if (value.isDefined) { + value.get + } else { + throw new NoSuchElementException(s"Cannot find param ${param.name}.") + } + } + + /** + * Checks whether a parameter is explicitly specified. + */ + def contains(param: Param[_]): Boolean = { + map.contains(param.asInstanceOf[Param[Any]]) + } + + /** + * Filters this param map for the given parent. + */ + def filter(parent: Params): ParamMap = { + val filtered = map.filterKeys(_.parent == parent) + new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + } + + /** + * Make a copy of this param map. + */ + def copy: ParamMap = new ParamMap(map.clone()) + + override def toString: String = { + map.map { case (param, value) => + s"\t${param.parent.uid}-${param.name}: $value" + }.mkString("{\n", ",\n", "\n}") + } + + /** + * Returns a new param map that contains parameters in this map and the given map, + * where the latter overwrites this if there exists conflicts. + */ + def ++(other: ParamMap): ParamMap = { + new ParamMap(this.map ++ other.map) + } + + + /** + * Adds all parameters from the input param map into this param map. + */ + def ++=(other: ParamMap): this.type = { + this.map ++= other.map + this + } + + /** + * Converts this param map to a sequence of param pairs. + */ + def toSeq: Seq[ParamPair[_]] = { + map.toSeq.map { case (param, value) => + ParamPair(param, value) + } + } +} + +object ParamMap { + + /** + * Returns an empty param map. + */ + def empty: ParamMap = new ParamMap() + + /** + * Constructs a param map by specifying its entries. + */ + @varargs + def apply(paramPairs: ParamPair[_]*): ParamMap = { + new ParamMap().put(paramPairs: _*) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala new file mode 100644 index 0000000000000..ef141d3eb2b06 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +private[ml] trait HasRegParam extends Params { + /** param for regularization parameter */ + val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + def getRegParam: Double = get(regParam) +} + +private[ml] trait HasMaxIter extends Params { + /** param for max number of iterations */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = get(maxIter) +} + +private[ml] trait HasFeaturesCol extends Params { + /** param for features column name */ + val featuresCol: Param[String] = + new Param(this, "featuresCol", "features column name", Some("features")) + def getFeaturesCol: String = get(featuresCol) +} + +private[ml] trait HasLabelCol extends Params { + /** param for label column name */ + val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) + def getLabelCol: String = get(labelCol) +} + +private[ml] trait HasScoreCol extends Params { + /** param for score column name */ + val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) + def getScoreCol: String = get(scoreCol) +} + +private[ml] trait HasPredictionCol extends Params { + /** param for prediction column name */ + val predictionCol: Param[String] = + new Param(this, "predictionCol", "prediction column name", Some("prediction")) + def getPredictionCol: String = get(predictionCol) +} + +private[ml] trait HasThreshold extends Params { + /** param for threshold in (binary) prediction */ + val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + def getThreshold: Double = get(threshold) +} + +private[ml] trait HasInputCol extends Params { + /** param for input column name */ + val inputCol: Param[String] = new Param(this, "inputCol", "input column name") + def getInputCol: String = get(inputCol) +} + +private[ml] trait HasOutputCol extends Params { + /** param for output column name */ + val outputCol: Param[String] = new Param(this, "outputCol", "output column name") + def getOutputCol: String = get(outputCol) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala new file mode 100644 index 0000000000000..194b9bfd9a9e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import com.github.fommil.netlib.F2jBLAS + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * Params for [[CrossValidator]] and [[CrossValidatorModel]]. + */ +private[ml] trait CrossValidatorParams extends Params { + /** param for the estimator to be cross-validated */ + val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + def getEstimator: Estimator[_] = get(estimator) + + /** param for estimator param maps */ + val estimatorParamMaps: Param[Array[ParamMap]] = + new Param(this, "estimatorParamMaps", "param maps for the estimator") + def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + + /** param for the evaluator for selection */ + val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + def getEvaluator: Evaluator = get(evaluator) + + /** param for number of folds for cross validation */ + val numFolds: IntParam = + new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + def getNumFolds: Int = get(numFolds) +} + +/** + * :: AlphaComponent :: + * K-fold cross validation. 
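A sketch of wiring the cross validator defined next around the pipeline from the earlier sketch; paramGrid is an assumed Array[ParamMap], for example produced by the ParamGridBuilder added later in this patch:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setNumFolds(3)
val cvModel = cv.fit(training, ParamMap.empty)
val predictions = cvModel.transform(test)   // delegates to the best model found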
+ */ +@AlphaComponent +class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { + + private val f2jBLAS = new F2jBLAS + + def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + def setNumFolds(value: Int): this.type = set(numFolds, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + val map = this.paramMap ++ paramMap + val schema = dataset.schema + transformSchema(dataset.schema, paramMap, logging = true) + val sqlCtx = dataset.sqlContext + val est = map(estimator) + val eval = map(evaluator) + val epm = map(estimatorParamMaps) + val numModels = epm.size + val metrics = new Array[Double](epm.size) + val splits = MLUtils.kFold(dataset, map(numFolds), 0) + splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => + val trainingDataset = sqlCtx.applySchema(training, schema).cache() + val validationDataset = sqlCtx.applySchema(validation, schema).cache() + // multi-model training + logDebug(s"Train split $splitIndex with multiple sets of parameters.") + val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + var i = 0 + while (i < numModels) { + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) + logDebug(s"Got metric $metric for model trained with ${epm(i)}.") + metrics(i) += metric + i += 1 + } + } + f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) + logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + logInfo(s"Best cross-validation metric: $bestMetric.") + val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + val cvModel = new CrossValidatorModel(this, map, bestModel) + Params.inheritValues(map, this, cvModel) + cvModel + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + map(estimator).transformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model from k-fold cross validation. + */ +@AlphaComponent +class CrossValidatorModel private[ml] ( + override val parent: CrossValidator, + override val fittingParamMap: ParamMap, + val bestModel: Model[_]) + extends Model[CrossValidatorModel] with CrossValidatorParams { + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + bestModel.transform(dataset, paramMap) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + bestModel.transformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala new file mode 100644 index 0000000000000..dafe73d82c00a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ + +/** + * :: AlphaComponent :: + * Builder for a param grid used in grid search-based model selection. + */ +@AlphaComponent +class ParamGridBuilder { + + private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] + + /** + * Sets the given parameters in this grid to fixed values. + */ + def baseOn(paramMap: ParamMap): this.type = { + baseOn(paramMap.toSeq: _*) + this + } + + /** + * Sets the given parameters in this grid to fixed values. + */ + @varargs + def baseOn(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + addGrid(p.param.asInstanceOf[Param[Any]], Seq(p.value)) + } + this + } + + /** + * Adds a param with multiple values (overwrites if the input param exists). + */ + def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { + paramGrid.put(param, values) + this + } + + // specialized versions of addGrid for Java. + + /** + * Adds a double param with multiple values. + */ + def addGrid(param: DoubleParam, values: Array[Double]): this.type = { + addGrid[Double](param, values) + } + + /** + * Adds a int param with multiple values. + */ + def addGrid(param: IntParam, values: Array[Int]): this.type = { + addGrid[Int](param, values) + } + + /** + * Adds a float param with multiple values. + */ + def addGrid(param: FloatParam, values: Array[Float]): this.type = { + addGrid[Float](param, values) + } + + /** + * Adds a long param with multiple values. + */ + def addGrid(param: LongParam, values: Array[Long]): this.type = { + addGrid[Long](param, values) + } + + /** + * Adds a boolean param with true and false. + */ + def addGrid(param: BooleanParam): this.type = { + addGrid[Boolean](param, Array(true, false)) + } + + /** + * Builds and returns all combinations of parameters specified by the param grid. 
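For example, the builder defined above expands one fixed value and two grids into 3 x 2 = 6 param maps (reusing lr and hashingTF from the earlier sketch):

import org.apache.spark.ml.tuning.ParamGridBuilder

val paramGrid = new ParamGridBuilder()
  .baseOn(lr.maxIter -> 50)                            // fixed in every combination
  .addGrid(lr.regParam, Array(0.1, 0.01, 0.001))
  .addGrid(hashingTF.numFeatures, Array(1000, 10000))
  .build()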
+ */ + def build(): Array[ParamMap] = { + var paramMaps = Array(new ParamMap) + paramGrid.foreach { case (param, values) => + val newParamMaps = values.flatMap { v => + paramMaps.map(_.copy.put(param.asInstanceOf[Param[Any]], v)) + } + paramMaps = newParamMaps.toArray + } + paramMaps + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 70d7138e3060f..9f20cd5d00dcd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream +import java.nio.{ByteBuffer, ByteOrder} import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -28,22 +29,22 @@ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ -import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.test.ChiSqTestResult +import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -74,10 +75,28 @@ class PythonMLLibAPI extends Serializable { learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], initialWeights: Vector): JList[Object] = { - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. - learner.disableUncachedWarning() - val model = learner.run(data.rdd, initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + try { + val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + + /** + * Return the Updater from string + */ + def getUpdaterFromString(regType: String): Updater = { + if (regType == "l2") { + new SquaredL2Updater + } else if (regType == "l1") { + new L1Updater + } else if (regType == null || regType == "none") { + new SimpleUpdater + } else { + throw new IllegalArgumentException("Invalid value for 'regType' parameter." 
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].") + } } /** @@ -99,14 +118,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - lrAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - lrAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: [l1, l2, none].") - } + lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, data, @@ -176,14 +188,7 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - SVMAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - SVMAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: [l1, l2, none].") - } + SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, data, @@ -209,14 +214,33 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - LogRegAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: [l1, l2, none].") - } + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) + trainRegressionModel( + LogRegAlg, + data, + initialWeights) + } + + /** + * Java stub for Python mllib LogisticRegressionWithLBFGS.train() + */ + def trainLogisticRegressionModelWithLBFGS( + data: JavaRDD[LabeledPoint], + numIterations: Int, + initialWeights: Vector, + regParam: Double, + regType: String, + intercept: Boolean, + corrections: Int, + tolerance: Double): JList[Object] = { + val LogRegAlg = new LogisticRegressionWithLBFGS() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setNumCorrections(corrections) + .setConvergenceTol(tolerance) + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, data, @@ -248,9 +272,11 @@ class PythonMLLibAPI extends Serializable { .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. 
- .disableUncachedWarning() - return kMeansAlg.run(data.rdd) + try { + kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) + } finally { + data.rdd.unpersist(blocking = false) + } } /** @@ -384,16 +410,18 @@ class PythonMLLibAPI extends Serializable { numPartitions: Int, numIterations: Int, seed: Long): Word2VecModelWrapper = { - val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) - val model = word2vec.fit(data) - data.unpersist() - new Word2VecModelWrapper(model) + try { + val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) + new Word2VecModelWrapper(model) + } finally { + dataJRDD.rdd.unpersist(blocking = false) + } } private[python] class Word2VecModelWrapper(model: Word2VecModel) { @@ -454,8 +482,50 @@ class PythonMLLibAPI extends Serializable { categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) + try { + DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy) + } finally { + data.rdd.unpersist(blocking = false) + } + } - DecisionTree.train(data.rdd, strategy) + /** + * Java stub for Python mllib RandomForest.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + */ + def trainRandomForestModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfo: JMap[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurityStr: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap) + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + if (algo == Algo.Classification) { + RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed) + } else { + RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed) + } + } finally { + cached.unpersist(blocking = false) + } } /** @@ -615,6 +685,7 @@ class PythonMLLibAPI extends Serializable { private[spark] object SerDe extends Serializable { val PYSPARK_PACKAGE = "pyspark.mllib" + val LATIN1 = "ISO-8859-1" /** * Base class used for pickle @@ -636,7 +707,7 @@ private[spark] object SerDe extends Serializable { def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { if (obj == this) { out.write(Opcodes.GLOBAL) - out.write((module + "\n" + name + "\n").getBytes()) + out.write((module + "\n" + name + "\n").getBytes) } else { pickler.save(this) // it will be memorized by Pickler saveState(obj, out, pickler) @@ -666,7 +737,16 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val vector: DenseVector = obj.asInstanceOf[DenseVector] - saveObjects(out, pickler, vector.toArray) + val bytes = new Array[Byte](8 * vector.size) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + db.put(vector.values) + + 
out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.TUPLE1) } def construct(args: Array[Object]): Object = { @@ -674,7 +754,13 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - new DenseVector(args(0).asInstanceOf[Array[Double]]) + val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bb = ByteBuffer.wrap(bytes, 0, bytes.length) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](bytes.length / 8) + db.get(ans) + Vectors.dense(ans) } } @@ -683,15 +769,30 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] - saveObjects(out, pickler, m.numRows, m.numCols, m.values) + val bytes = new Array[Byte](8 * m.values.size) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numRows)) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numCols)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.TUPLE3) } def construct(args: Array[Object]): Object = { if (args.length != 3) { throw new PickleException("should be 3") } - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], - args(2).asInstanceOf[Array[Double]]) + val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val n = bytes.length / 8 + val values = new Array[Double](n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) } } @@ -700,15 +801,40 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val v: SparseVector = obj.asInstanceOf[SparseVector] - saveObjects(out, pickler, v.size, v.indices, v.values) + val n = v.indices.size + val indiceBytes = new Array[Byte](4 * n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices) + val valueBytes = new Array[Byte](8 * n) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values) + + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(v.size)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(indiceBytes.length)) + out.write(indiceBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(valueBytes.length)) + out.write(valueBytes) + out.write(Opcodes.TUPLE3) } def construct(args: Array[Object]): Object = { if (args.length != 3) { throw new PickleException("should be 3") } - new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]], - args(2).asInstanceOf[Array[Double]]) + val size = args(0).asInstanceOf[Int] + val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) + val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val n = indiceBytes.length / 4 + val indices = new Array[Int](n) + val values = new Array[Double](n) + if (n > 0) { + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values) + } + new SparseVector(size, indices, values) } } diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 84d3c7cebd7c8..94d757bc317ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -64,16 +64,17 @@ class LogisticRegressionModel ( val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { - case Some(t) => if (score < t) 0.0 else 1.0 + case Some(t) => if (score > t) 1.0 else 0.0 case None => score } } } /** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * + * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By + * default L2 regularization is used, which can be changed via + * [[LogisticRegressionWithSGD.optimizer]]. + * NOTE: Labels used in Logistic Regression should be {0, 1}. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ class LogisticRegressionWithSGD private ( @@ -93,9 +94,10 @@ class LogisticRegressionWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a LogisticRegression object with default parameters + * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, + * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 0.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 80f8a1b2f1e84..dd514ff8a37f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -65,14 +65,15 @@ class SVMModel ( intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept threshold match { - case Some(t) => if (margin < t) 0.0 else 1.0 + case Some(t) => if (margin > t) 1.0 else 0.0 case None => margin } } } /** - * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. + * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 + * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ class SVMWithSGD private ( @@ -92,9 +93,10 @@ class SVMWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a SVM object with default parameters + * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, + * regParm: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new SVMModel(weights, intercept) @@ -185,6 +187,6 @@ object SVMWithSGD { * @return a SVMModel which has the weights and offset from training. 
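Since the SGD-based trainers above now default to regParam = 0.01, here is a minimal, hedged usage sketch of pinning the regularization explicitly; the LIBSVM path is a placeholder and any RDD[LabeledPoint] with {0, 1} labels would do.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.util.MLUtils

object SvmRegParamSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("SvmRegParamSketch"))
    // Hypothetical LIBSVM file with binary {0, 1} labels.
    val input = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt").cache()

    // Picks up the new default regParam = 0.01.
    val defaultModel = SVMWithSGD.train(input, 100)

    // Passes stepSize, regParam and miniBatchFraction explicitly to keep the old
    // behavior (regParam = 1.0).
    val legacyModel = SVMWithSGD.train(input, 100, 1.0, 1.0, 1.0)

    println(s"default intercept = ${defaultModel.intercept}, legacy intercept = ${legacyModel.intercept}")
    sc.stop()
  }
}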
*/ def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 7443f232ec3e7..34ea0de706f08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -113,22 +113,13 @@ class KMeans private ( this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. */ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ def run(data: RDD[Vector]): KMeansModel = { - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -143,7 +134,7 @@ class KMeans private ( norms.unpersist() // Warn at the end of the run as well, for increased visibility. - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala index 562663ad36b40..be3319d60ce25 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -24,26 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl def apply(c: BinaryConfusionMatrix): Double } -/** Precision. */ +/** Precision. Defined as 1.0 when there are no positive examples. */ private[evaluation] object Precision extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) + override def apply(c: BinaryConfusionMatrix): Double = { + val totalPositives = c.numTruePositives + c.numFalsePositives + if (totalPositives == 0) { + 1.0 + } else { + c.numTruePositives.toDouble / totalPositives + } + } } -/** False positive rate. */ +/** False positive rate. Defined as 0.0 when there are no negative examples. */ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numFalsePositives.toDouble / c.numNegatives + override def apply(c: BinaryConfusionMatrix): Double = { + if (c.numNegatives == 0) { + 0.0 + } else { + c.numFalsePositives.toDouble / c.numNegatives + } + } } -/** Recall. */ +/** Recall. Defined as 0.0 when there are no positive examples. 
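A small, self-contained sketch of the edge-case conventions adopted above (precision defined as 1.0 with no predicted positives, recall and false positive rate defined as 0.0 with no positives/negatives); the counts are made up.

object MetricEdgeCaseSketch {
  // tp = true positives, fp = false positives, fn = false negatives, tn = true negatives.
  def precision(tp: Long, fp: Long): Double =
    if (tp + fp == 0) 1.0 else tp.toDouble / (tp + fp)

  def recall(tp: Long, fn: Long): Double =
    if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)

  def falsePositiveRate(fp: Long, tn: Long): Double =
    if (fp + tn == 0) 0.0 else fp.toDouble / (fp + tn)

  def main(args: Array[String]): Unit = {
    // A model that never predicts positive: precision is 1.0 by convention, recall is 0.0.
    println(precision(tp = 0, fp = 0))          // 1.0
    println(recall(tp = 0, fn = 5))             // 0.0
    println(falsePositiveRate(fp = 0, tn = 10)) // 0.0
  }
}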
*/ private[evaluation] object Recall extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numTruePositives.toDouble / c.numPositives + override def apply(c: BinaryConfusionMatrix): Double = { + if (c.numPositives == 0) { + 0.0 + } else { + c.numTruePositives.toDouble / c.numPositives + } + } } /** - * F-Measure. + * F-Measure. Defined as 0 if both precision and recall are 0. EG in the case that all examples + * are false positives. * @param beta the beta constant in F-Measure * @see http://en.wikipedia.org/wiki/F1_score */ @@ -52,6 +69,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificati override def apply(c: BinaryConfusionMatrix): Double = { val precision = Precision(c) val recall = Recall(c) - (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + if (precision + recall == 0) { + 0.0 + } else { + (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index dfad25d57c947..a9c2e23717896 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm} +import breeze.linalg.{norm => brzNorm} import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** * :: Experimental :: @@ -47,22 +47,31 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = brzNorm(vector.toBreeze, p) + val norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector. // However, for sparse vector, the `index` array will not be changed, // so we can re-use it to save memory. 
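To make the F-measure guard above concrete, a stand-alone sketch of the same formula, F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), with the zero-division check; the precision/recall values are arbitrary.

object FMeasureSketch {
  /** Defined as 0.0 when precision and recall are both 0, instead of NaN. */
  def fMeasure(precision: Double, recall: Double, beta: Double = 1.0): Double = {
    val beta2 = beta * beta
    if (precision + recall == 0.0) {
      0.0
    } else {
      (1.0 + beta2) * precision * recall / (beta2 * precision + recall)
    }
  }

  def main(args: Array[String]): Unit = {
    println(fMeasure(0.5, 0.25))      // 0.333... (standard F1)
    println(fMeasure(0.5, 0.25, 2.0)) // ~0.278 (recall-weighted F2)
    println(fMeasure(0.0, 0.0))       // 0.0 rather than NaN
  }
}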
- vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) - case sv: BSV[Double] => - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size var i = 0 - while (i < output.data.length) { - output.data(i) /= norm + while (i < size) { + values(i) /= norm i += 1 } - Vectors.fromBreeze(output) + Vectors.dense(values) + case sv: SparseVector => + val values = sv.values.clone() + val nnz = values.size + var i = 0 + while (i < nnz) { + values(i) /= norm + i += 1 + } + Vectors.sparse(sv.size, sv.indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 4dfd1f0ab8134..8c4c5db5258d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] ( require(mean.size == variance.size) - private lazy val factor: BDV[Double] = { - val f = BDV.zeros[Double](variance.size) + private lazy val factor: Array[Double] = { + val f = Array.ofDim[Double](variance.size) var i = 0 while (i < f.size) { f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 @@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] ( f } + // Since `shift` will be only used in `withMean` branch, we have it as + // `lazy val` so it will be evaluated in that branch. Note that we don't + // want to create this array multiple times in `transform` function. + private lazy val shift: Array[Double] = mean.toArray + /** * Applies standardization transformation on a vector. * @@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] ( override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { - vector.toBreeze match { - case dv: BDV[Double] => - val output = vector.toBreeze.copy - var i = 0 - while (i < output.length) { - output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) - i += 1 + // By default, Scala generates Java methods for member variables. So every time when + // the member variables are accessed, `invokespecial` will be called which is expensive. + // This can be avoid by having a local reference of `shift`. + val localShift = shift + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + if (withStd) { + // Having a local reference of `factor` to avoid overhead as the comment before. 
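A quick usage sketch of the Normalizer behavior described above (the sample vectors are arbitrary): each entry is divided by the chosen p-norm, sparse indices are reused, and an all-zero vector is returned unchanged.

import org.apache.spark.mllib.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

object NormalizerSketch {
  def main(args: Array[String]): Unit = {
    val l2 = new Normalizer()        // p = 2 by default
    val l1 = new Normalizer(p = 1.0)

    println(l2.transform(Vectors.dense(3.0, 4.0)))  // [0.6, 0.8]
    // Sparse input keeps its indices; values are scaled by 1 / (|2.0| + |-2.0|) = 1/4.
    println(l1.transform(Vectors.sparse(4, Array(0, 3), Array(2.0, -2.0))))
    println(l2.transform(Vectors.dense(0.0, 0.0)))  // zero norm: input returned as-is
  }
}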
+ val localFactor = factor + var i = 0 + while (i < size) { + values(i) = (values(i) - localShift(i)) * localFactor(i) + i += 1 + } + } else { + var i = 0 + while (i < size) { + values(i) -= localShift(i) + i += 1 + } } - Vectors.fromBreeze(output) + Vectors.dense(values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { - vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) - case sv: BSV[Double] => + // Having a local reference of `factor` to avoid overhead as the comment before. + val localFactor = factor + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + var i = 0 + while(i < size) { + values(i) *= localFactor(i) + i += 1 + } + Vectors.dense(values) + case sv: SparseVector => // For sparse vector, the `index` array inside sparse vector object will not be changed, // so we can re-use it to save memory. - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + val indices = sv.indices + val values = sv.values.clone() + val nnz = values.size var i = 0 - while (i < output.data.length) { - output.data(i) *= factor(output.index(i)) + while (i < nnz) { + values(i) *= localFactor(indices(i)) i += 1 } - Vectors.fromBreeze(output) + Vectors.sparse(sv.size, indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f5f7ad613d4c4..7960f3cab576f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -461,4 +461,11 @@ class Word2VecModel private[mllib] ( .tail .toArray } + + /** + * Returns a map of words to their vector representations. + */ + def getVectors: Map[String, Array[Float]] = { + model + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 54ee930d61003..89539e600f48c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -25,7 +25,7 @@ import org.apache.spark.Logging /** * BLAS routines for MLlib's vectors and matrices. */ -private[mllib] object BLAS extends Serializable with Logging { +private[spark] object BLAS extends Serializable with Logging { @transient private var _f2jBLAS: NetlibBLAS = _ @transient private var _nativeBLAS: NetlibBLAS = _ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 2cc52e94282ba..327366a1a3a82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,12 +17,10 @@ package org.apache.spark.mllib.linalg -import java.util.Arrays +import java.util.{Random, Arrays} import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} -import org.apache.spark.util.random.XORShiftRandom - /** * Trait for a local matrix. */ @@ -67,14 +65,14 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. 
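A short usage sketch of the StandardScalerModel paths rewritten above: withStd alone preserves sparsity, while withMean centers and therefore produces a dense result. The sample vectors are arbitrary.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

object StandardScalerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("StandardScalerSketch"))
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 10.0),
      Vectors.dense(3.0, 30.0),
      Vectors.dense(5.0, 50.0)))

    // withStd only: sparse vectors keep their index array, values are scaled on a copy.
    val stdOnly = new StandardScaler(withMean = false, withStd = true).fit(data)
    println(stdOnly.transform(Vectors.sparse(2, Array(1), Array(10.0))))

    // withMean as well: centering necessarily densifies the output.
    val centered = new StandardScaler(withMean = true, withStd = true).fit(data)
    println(centered.transform(Vectors.dense(3.0, 30.0))) // roughly [0.0, 0.0]
    sc.stop()
  }
}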
*/ - def transposeMultiply(y: DenseMatrix): DenseMatrix = { + private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] BLAS.gemm(true, false, 1.0, this, y, 0.0, C) C } /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - def transposeMultiply(y: DenseVector): DenseVector = { + private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { val output = new DenseVector(new Array[Double](numCols)) BLAS.gemv(true, 1.0, this, y, 0.0, output) output @@ -291,22 +289,22 @@ object Matrices { * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def rand(numRows: Int, numCols: Int): Matrix = { - val rand = new XORShiftRandom - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble())) + def rand(numRows: Int, numCols: Int, rng: Random): Matrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def randn(numRows: Int, numCols: Int): Matrix = { - val rand = new XORShiftRandom - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian())) + def randn(numRows: Int, numCols: Int, rng: Random): Matrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index ac217edc619ab..c6d5fe5bc678c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -76,6 +76,15 @@ sealed trait Vector extends Serializable { def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } + + /** + * Applies a function `f` to all the active elements of dense and sparse vector. + * + * @param f the function takes two parameters where the first parameter is the index of + * the vector with type `Int`, and the second parameter is the corresponding value + * with type `Double`. 
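A brief sketch of the new Matrices.rand / Matrices.randn signatures above, which now take an explicit java.util.Random so callers control seeding; the seed and sizes are arbitrary.

import java.util.Random

import org.apache.spark.mllib.linalg.Matrices

object RandomMatrixSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(42L)
    // 3 x 2 matrix with i.i.d. U(0, 1) entries; reusing the same seed reproduces the matrix.
    val uniform = Matrices.rand(3, 2, rng)
    // 3 x 2 matrix with i.i.d. N(0, 1) entries.
    val gaussian = Matrices.randn(3, 2, new Random(42L))
    println(uniform)
    println(gaussian)
  }
}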
+ */ + private[spark] def foreachActive(f: (Int, Double) => Unit) } /** @@ -115,6 +124,9 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def deserialize(datum: Any): Vector = { datum match { + // TODO: something wrong with UDT serialization + case v: Vector => + v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") @@ -234,7 +246,7 @@ object Vectors { private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => - if (v.offset == 0 && v.stride == 1) { + if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { new DenseVector(v.data) } else { new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one @@ -270,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector { override def copy: DenseVector = { new DenseVector(values.clone()) } + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localValues = values + + while (i < localValuesSize) { + f(i, localValues(i)) + i += 1 + } + } } /** @@ -306,4 +329,16 @@ class SparseVector( } private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localIndices = indices + val localValues = values + + while (i < localValuesSize) { + f(localIndices(i), localValues(i)) + i += 1 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index a6912056395d7..0857877951c82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -160,14 +160,15 @@ object GradientDescent extends Logging { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) val numExamples = data.count() - val miniBatchSize = numExamples * miniBatchFraction // if no data, return initial weights to avoid NaNs if (numExamples == 0) { - - logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found") + logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found") return (initialWeights, stochasticLossHistory.toArray) + } + if (numExamples * miniBatchFraction < 1) { + logWarning("The miniBatchFraction is too small") } // Initialize weights as a column vector @@ -185,25 +186,31 @@ object GradientDescent extends Logging { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .treeAggregate((BDV.zeros[Double](n), 0.0))( - seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) - (grad, loss + l) + val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i) + .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))( + seqOp = (c, v) => { + // c: (grad, loss, count), v: (label, features) + val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1)) + (c._1, c._2 + l, c._3 + 1) }, - combOp = (c1, c2) => 
(c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + combOp = (c1, c2) => { + // c: (grad, loss, count) + (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3) }) - /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - stochasticLossHistory.append(lossSum / miniBatchSize + regVal) - val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam) - weights = update._1 - regVal = update._2 + if (miniBatchSize > 0) { + /** + * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. + */ + stochasticLossHistory.append(lossSum / miniBatchSize + regVal) + val update = updater.compute( + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam) + weights = update._1 + regVal = update._2 + } else { + logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") + } } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 84d192db53e26..90ac252226006 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.{abs, sqrt} -import scala.util.Random -import scala.util.Sorting +import scala.util.{Random, Sorting} import scala.util.hashing.byteswap32 import org.jblas.{DoubleMatrix, SimpleBlas, Solve} +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.broadcast.Broadcast -import org.apache.spark.{Logging, HashPartitioner, Partitioner} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -import org.apache.spark.mllib.optimization.NNLS /** * Out-link information for a user or product block. This includes the original user/product IDs @@ -325,6 +325,11 @@ class ALS private ( new MatrixFactorizationModel(rank, usersOut, productsOut) } + /** + * Java-friendly version of [[ALS.run]]. + */ + def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) + /** * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors * for each user (or product), in a distributed fashion. 
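The fix above replaces the estimated mini-batch size (numExamples * miniBatchFraction) with the actual count of sampled records carried through treeAggregate, and skips the update when the sample is empty. Here is a minimal stand-alone sketch of that (sum, count) pattern over a sampled RDD; the object name and values are made up.

import org.apache.spark.{SparkConf, SparkContext}

object SampledMeanSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("SampledMeanSketch"))
    val data = sc.parallelize((1 to 1000).map(_.toDouble))

    // One pass over the sampled subset, accumulating (sum, count) like the SGD mini-batch.
    val (sum, count) = data.sample(withReplacement = false, fraction = 0.1, seed = 42L)
      .treeAggregate((0.0, 0L))(
        seqOp = (acc, v) => (acc._1 + v, acc._2 + 1L),
        combOp = (a, b) => (a._1 + b._1, a._2 + b._2))

    if (count > 0) {
      println(s"sampled mean = ${sum / count} over $count records")
    } else {
      println("sampled batch was empty; skipping the update")
    }
    sc.stop()
  }
}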
@@ -741,7 +746,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @param alpha confidence parameter (only applies when immplicitPrefs = true) + * @param alpha confidence parameter * @param seed random seed */ def trainImplicit( @@ -768,7 +773,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @param alpha confidence parameter (only applies when immplicitPrefs = true) + * @param alpha confidence parameter */ def trainImplicit( ratings: RDD[Rating], @@ -792,6 +797,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) + * @param alpha confidence parameter */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 66b58ba770160..ed2f8b41bcae5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,27 +17,49 @@ package org.apache.spark.mllib.recommendation +import java.lang.{Integer => JavaInteger} + import org.jblas.DoubleMatrix -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.Logging +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.api.python.SerDe +import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. * + * Note: If you create the model directly using constructor, please be aware that fast prediction + * requires cached user/product features and their associated partitioners. + * * @param rank Rank for the features in this model. * @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. */ -class MatrixFactorizationModel private[mllib] ( +class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable { + val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + + require(rank > 0) + validateFeatures("User", userFeatures) + validateFeatures("Product", productFeatures) + + /** Validates factors and warns users if there are performance concerns. */ + private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = { + require(features.first()._2.size == rank, + s"$name feature dimension does not match the rank $rank.") + if (features.partitioner.isEmpty) { + logWarning(s"$name factor does not have a partitioner. " + + "Prediction on individual records could be slow.") + } + if (features.getStorageLevel == StorageLevel.NONE) { + logWarning(s"$name factor is not cached. 
Prediction could be slow.") + } + } + /** Predict the rating of one user for one product. */ def predict(user: Int, product: Int): Double = { val userVector = new DoubleMatrix(userFeatures.lookup(user).head) @@ -65,6 +87,13 @@ class MatrixFactorizationModel private[mllib] ( } } + /** + * Java-friendly version of [[MatrixFactorizationModel.predict]]. + */ + def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { + predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() + } + /** * Recommends products to a user. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 00dfc86c9e0bd..0287f04e2c777 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. */ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { - if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } // Warn at the end of the run as well, for increased visibility. - if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 17c753c56681f..2067b36f246b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.regression +import scala.beans.BeanInfo + import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -27,6 +29,7 @@ import org.apache.spark.SparkException * @param label Label for this data point. * @param features List of features for this data point. 
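Since the constructor is now public, here is a hedged sketch of building a MatrixFactorizationModel directly in a way that avoids the two warnings above, i.e. with factor RDDs that are partitioned and cached; the factor values are dummy placeholders.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.storage.StorageLevel

object DirectMfmSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("DirectMfmSketch"))
    val rank = 2
    // Dummy factors; real ones would come from a previous training run or external storage.
    val userFeatures = sc.parallelize(Seq(1 -> Array(0.1, 0.2), 2 -> Array(0.3, 0.4)))
      .partitionBy(new HashPartitioner(sc.defaultParallelism))
      .persist(StorageLevel.MEMORY_AND_DISK)
    val productFeatures = sc.parallelize(Seq(10 -> Array(0.5, 0.6), 20 -> Array(0.7, 0.8)))
      .partitionBy(new HashPartitioner(sc.defaultParallelism))
      .persist(StorageLevel.MEMORY_AND_DISK)

    // Partitioned + cached factors keep validateFeatures() quiet and predictions fast.
    val model = new MatrixFactorizationModel(rank, userFeatures, productFeatures)
    println(model.predict(1, 10))
    sc.stop()
  }
}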
*/ +@BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { "(%s,%s)".format(label, features) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index cb0d39e759a9f..f9791c6571782 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -67,9 +67,9 @@ class LassoWithSGD private ( /** * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LassoModel(weights, intercept) @@ -161,6 +161,6 @@ object LassoWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index a826deb695ee1..c8cad773f5efb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -68,9 +68,9 @@ class RidgeRegressionWithSGD private ( /** * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new RidgeRegressionModel(weights, intercept) @@ -143,7 +143,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 1.0) + train(input, numIterations, stepSize, regParam, 0.01) } /** @@ -158,6 +158,6 @@ object RidgeRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index fab7c4405c65d..fcc2a148791bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,10 +17,8 @@ package org.apache.spark.mllib.stat -import breeze.linalg.{DenseVector => BDV} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: @@ -40,14 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { private var n = 0 - private var currMean: BDV[Double] = _ - private var currM2n: BDV[Double] = _ - private var currM2: BDV[Double] = _ - private var currL1: BDV[Double] = _ + private var currMean: Array[Double] = _ + 
private var currM2n: Array[Double] = _ + private var currM2: Array[Double] = _ + private var currL1: Array[Double] = _ private var totalCnt: Long = 0 - private var nnz: BDV[Double] = _ - private var currMax: BDV[Double] = _ - private var currMin: BDV[Double] = _ + private var nnz: Array[Double] = _ + private var currMax: Array[Double] = _ + private var currMin: Array[Double] = _ /** * Add a new sample to this summarizer, and update the statistical summary. @@ -60,52 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(sample.size > 0, s"Vector should have dimension larger than zero.") n = sample.size - currMean = BDV.zeros[Double](n) - currM2n = BDV.zeros[Double](n) - currM2 = BDV.zeros[Double](n) - currL1 = BDV.zeros[Double](n) - nnz = BDV.zeros[Double](n) - currMax = BDV.fill(n)(Double.MinValue) - currMin = BDV.fill(n)(Double.MaxValue) + currMean = Array.ofDim[Double](n) + currM2n = Array.ofDim[Double](n) + currM2 = Array.ofDim[Double](n) + currL1 = Array.ofDim[Double](n) + nnz = Array.ofDim[Double](n) + currMax = Array.fill[Double](n)(Double.MinValue) + currMin = Array.fill[Double](n)(Double.MaxValue) } require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - @inline def update(i: Int, value: Double) = { + sample.foreachActive { (index, value) => if (value != 0.0) { - if (currMax(i) < value) { - currMax(i) = value + if (currMax(index) < value) { + currMax(index) = value } - if (currMin(i) > value) { - currMin(i) = value + if (currMin(index) > value) { + currMin(index) = value } - val tmpPrevMean = currMean(i) - currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) - currM2(i) += value * value - currL1(i) += math.abs(value) - - nnz(i) += 1.0 - } - } + val prevMean = currMean(index) + val diff = value - prevMean + currMean(index) = prevMean + diff / (nnz(index) + 1.0) + currM2n(index) += (value - currMean(index)) * diff + currM2(index) += value * value + currL1(index) += math.abs(value) - sample match { - case dv: DenseVector => { - var j = 0 - while (j < dv.size) { - update(j, dv.values(j)) - j += 1 - } + nnz(index) += 1.0 } - case sv: SparseVector => - var j = 0 - while (j < sv.indices.size) { - update(sv.indices(j), sv.values(j)) - j += 1 - } - case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } totalCnt += 1 @@ -124,47 +106,38 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == other.n, s"Dimensions mismatch when merging with another summarizer. 
" + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt - val deltaMean: BDV[Double] = currMean - other.currMean var i = 0 while (i < n) { - // merge mean together - if (other.currMean(i) != 0.0) { - currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / - (nnz(i) + other.nnz(i)) - } - // merge m2n together - if (nnz(i) + other.nnz(i) != 0.0) { - currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / - (nnz(i) + other.nnz(i)) - } - // merge m2 together - if (nnz(i) + other.nnz(i) != 0.0) { + val thisNnz = nnz(i) + val otherNnz = other.nnz(i) + val totalNnz = thisNnz + otherNnz + if (totalNnz != 0.0) { + val deltaMean = other.currMean(i) - currMean(i) + // merge mean together + currMean(i) += deltaMean * otherNnz / totalNnz + // merge m2n together + currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz + // merge m2 together currM2(i) += other.currM2(i) - } - // merge l1 together - if (nnz(i) + other.nnz(i) != 0.0) { + // merge l1 together currL1(i) += other.currL1(i) + // merge max and min + currMax(i) = math.max(currMax(i), other.currMax(i)) + currMin(i) = math.min(currMin(i), other.currMin(i)) } - - if (currMax(i) < other.currMax(i)) { - currMax(i) = other.currMax(i) - } - if (currMin(i) > other.currMin(i)) { - currMin(i) = other.currMin(i) - } + nnz(i) = totalNnz i += 1 } - nnz += other.nnz } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.copy - this.currM2n = other.currM2n.copy - this.currM2 = other.currM2.copy - this.currL1 = other.currL1.copy + this.currMean = other.currMean.clone + this.currM2n = other.currM2n.clone + this.currM2 = other.currM2.clone + this.currL1 = other.currL1.clone this.totalCnt = other.totalCnt - this.nnz = other.nnz.copy - this.currMax = other.currMax.copy - this.currMin = other.currMin.copy + this.nnz = other.nnz.clone + this.currMax = other.currMax.clone + this.currMin = other.currMin.clone } this } @@ -172,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMean = BDV.zeros[Double](n) + val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { realMean(i) = currMean(i) * (nnz(i) / totalCnt) i += 1 } - Vectors.fromBreeze(realMean) + Vectors.dense(realMean) } override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realVariance = BDV.zeros[Double](n) + val realVariance = Array.ofDim[Double](n) val denominator = totalCnt - 1.0 @@ -199,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S i += 1 } } - - Vectors.fromBreeze(realVariance) + Vectors.dense(realVariance) } override def count: Long = totalCnt @@ -208,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(nnz) + Vectors.dense(nnz) } override def max: Vector = { @@ -219,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMax) + Vectors.dense(currMax) } override def min: Vector = { @@ -230,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && 
(currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMin) + Vectors.dense(currMin) } override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMagnitude = BDV.zeros[Double](n) + val realMagnitude = Array.ofDim[Double](n) var i = 0 while (i < currM2.size) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } - - Vectors.fromBreeze(realMagnitude) + Vectors.dense(realMagnitude) } override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(currL1) + + Vectors.dense(currL1) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 78acc17f901c1..3d91867c896d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) - val rfModel = rf.train(input) - rfModel.weakHypotheses(0) + val rfModel = rf.run(input) + rfModel.trees(0) } + /** + * Trains a decision tree model over an RDD. This is deprecated because it hides the static + * methods with the same name in Java. + */ + @deprecated("Please use DecisionTree.run instead.", "1.2.0") + def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input) } object DecisionTree extends Serializable with Logging { @@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging { * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, numClassesForClassification: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging { categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala similarity index 51% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 
f729344a682e2..61f6b1313f82e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -21,170 +21,117 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** * :: Experimental :: - * A class that implements Stochastic Gradient Boosting - * for regression and binary classification problems. + * A class that implements + * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]] + * for regression and binary classification. * * The implementation is based upon: * J.H. Friedman. "Stochastic Gradient Boosting." 1999. * - * Notes: - * - This currently can be run with several loss functions. However, only SquaredError is - * fully supported. Specifically, the loss function should be used to compute the gradient - * (to re-label training instances on each iteration) and to weight weak hypotheses. - * Currently, gradients are computed correctly for the available loss functions, - * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. - * Running with those losses will likely behave reasonably, but lacks the same guarantees. + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - When the loss is SquaredError, these methods give the same result, but they could differ + * for other loss functions. * - * @param boostingStrategy Parameters for the gradient boosting algorithm + * @param boostingStrategy Parameters for the gradient boosting algorithm. */ @Experimental -class GradientBoosting ( - private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { - - boostingStrategy.weakLearnerParams.algo = Regression - boostingStrategy.weakLearnerParams.impurity = impurity.Variance - - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - boostingStrategy.weakLearnerParams.numClassesForClassification = - boostingStrategy.numClassesForClassification - - boostingStrategy.assertValid() +class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) + extends Serializable with Logging { /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
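A hedged end-to-end sketch of the renamed API, assuming BoostingStrategy.defaultParams is available to build a starting configuration; the LIBSVM path and hyperparameter values are placeholders.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

object GbtRenameSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GbtRenameSketch"))
    // Hypothetical LIBSVM file with {0, 1} labels.
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt").cache()

    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10
    boostingStrategy.treeStrategy.maxDepth = 3

    // GradientBoosting.train(...) becomes GradientBoostedTrees.train(...), and it now
    // returns a GradientBoostedTreesModel instead of a WeightedEnsembleModel.
    val model = GradientBoostedTrees.train(data, boostingStrategy)
    println(s"Learned ${model.trees.length} trees")
    sc.stop()
  }
}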
- * @return WeightedEnsembleModel that can be used for prediction + * @return a gradient boosted trees model that can be used for prediction */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { - val algo = boostingStrategy.algo + def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoosting.boost(input, boostingStrategy) + case Regression => GradientBoostedTrees.boost(input, boostingStrategy) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoosting.boost(remappedInput, boostingStrategy) + GradientBoostedTrees.boost(remappedInput, boostingStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } } + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. + */ + def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + run(input.rdd) + } } -object GradientBoosting extends Logging { +object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model. * - * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - * is recommended to clearly specify regression. - * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - * is recommended to clearly specify regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction + * @return a gradient boosted trees model that can be used for prediction */ def train( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - new GradientBoosting(boostingStrategy).train(input) + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + new GradientBoostedTrees(boostingStrategy).run(input) } /** - * Method to train a gradient boosting classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainClassifier( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. 
- * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] */ def train( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - train(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - */ - def trainClassifier( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainClassifier(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - */ - def trainRegressor( input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainRegressor(input.rdd, boostingStrategy) + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + train(input.rdd, boostingStrategy) } /** * Internal method for performing regression using trees as base learners. * @param input training dataset * @param boostingStrategy boosting parameters - * @return + * @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") timer.start("init") + boostingStrategy.assertValid() + // Initialize gradient boosting parameters val numIterations = boostingStrategy.numIterations val baseLearners = new Array[DecisionTreeModel](numIterations) val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate - val strategy = boostingStrategy.weakLearnerParams + // Prepare strategy for individual trees, which use regression with variance impurity. 
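To make the weak-learner setup concrete, a hedged sketch of deriving the per-tree Strategy the same way boost() does above: whatever the outer algo is, each tree is fit as a variance-impurity regression tree against pseudo-residuals. BoostingStrategy.defaultParams and the "Classification" argument are assumptions for illustration.

import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance

object WeakLearnerStrategySketch {
  def main(args: Array[String]): Unit = {
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")

    // Same steps as boost(): copy the user's tree settings, then force regression + variance.
    val treeStrategy: Strategy = boostingStrategy.treeStrategy.copy
    treeStrategy.algo = Algo.Regression
    treeStrategy.impurity = Variance
    treeStrategy.assertValid()

    println(s"weak learner algo = ${treeStrategy.algo}, impurity = ${treeStrategy.impurity}")
  }
}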
+ val treeStrategy = boostingStrategy.treeStrategy.copy + treeStrategy.algo = Regression + treeStrategy.impurity = Variance + treeStrategy.assertValid() // Cache input if (input.getStorageLevel == StorageLevel.NONE) { @@ -200,11 +147,10 @@ object GradientBoosting extends Logging { // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(strategy).train(data) + val firstTreeModel = new DecisionTree(treeStrategy).run(data) baseLearners(0) = firstTreeModel baseLearnerWeights(0) = 1.0 - val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, - Sum) + val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) logDebug("error of gbt = " + loss.computeError(startingModel, input)) // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") @@ -219,7 +165,7 @@ object GradientBoosting extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") - val model = new DecisionTree(strategy).train(data) + val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") // Create partial model baseLearners(m) = model @@ -228,8 +174,8 @@ object GradientBoosting extends Logging { // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction - val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1), Regression, Sum) + val partialModel = new GradientBoostedTreesModel( + Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) logDebug("error of gbt = " + loss.computeError(partialModel, input)) // Update data with pseudo-residuals data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), @@ -242,8 +188,7 @@ object GradientBoosting extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) - + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 9683916d9b3f1..482d3395516e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,18 +17,18 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache } +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, 
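The loop above is plain functional gradient descent: fit a regression tree to the current pseudo-residuals, shrink its contribution by learningRate (the first learner keeps weight 1.0), then recompute the residuals. A self-contained toy version of the same recurrence, with the tree swapped for the simplest base learner (a constant, the L2-optimal fit) and using the 1/2 (y - F)^2 convention so the pseudo-residual is exactly y - F, makes the update easy to check by hand; all names here are hypothetical:

def boostedConstant(labels: Array[Double], numIterations: Int, learningRate: Double): Double = {
  var prediction = 0.0  // F; with constant base learners, F(x) collapses to one number
  var m = 0
  while (m < numIterations) {
    val residuals = labels.map(_ - prediction)  // pseudo-residuals y - F
    val fit = residuals.sum / residuals.length  // the "tree": best constant under squared error
    prediction += (if (m == 0) 1.0 else learningRate) * fit  // first learner gets weight 1.0
    m += 1
  }
  prediction  // approaches the label mean, the L2 optimum for a constant model
}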
NodeIdCache, + TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -37,7 +37,8 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * A class which implements a random forest learning algorithm for classification and regression. + * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] + * learning algorithm for classification and regression. * It supports both continuous and categorical features. * * The settings for featureSubsetStrategy are based on the following references: @@ -70,6 +71,47 @@ private class RandomForest ( private val seed: Int) extends Serializable with Logging { + /* + ALGORITHM + This is a sketch of the algorithm to help new developers. + + The algorithm partitions data by instances (rows). + On each iteration, the algorithm splits a set of nodes. In order to choose the best split + for a given node, sufficient statistics are collected from the distributed data. + For each node, the statistics are collected to some worker node, and that worker selects + the best split. + + This setup requires discretization of continuous features. This binning is done in the + findSplitsBins() method during initialization, after which each continuous feature becomes + an ordered discretized feature with at most maxBins possible values. + + The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes + lie at the periphery of the tree being trained. If multiple trees are being trained at once, + then this queue contains nodes from all of them. Each iteration works roughly as follows: + On the master node: + - Some number of nodes are pulled off of the queue (based on the amount of memory + required for their sufficient statistics). + - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate + features are chosen for each node. See method selectNodesToSplit(). + On worker nodes, via method findBestSplits(): + - The worker makes one pass over its subset of instances. + - For each (tree, node, feature, split) tuple, the worker collects statistics about + splitting. Note that the set of (tree, node) pairs is limited to the nodes selected + from the queue for this iteration. The set of features considered can also be limited + based on featureSubsetStrategy. + - For each node, the statistics for that node are aggregated to a particular worker + via reduceByKey(). The designated worker chooses the best (feature, split) pair, + or chooses to stop splitting if the stopping criteria are met. + On the master node: + - The master collects all decisions about splitting nodes and updates the model. + - The updated model is passed to the workers on the next iteration. + This process continues until the node queue is empty. + + Most of the methods in this implementation support the statistics aggregation, which is + the heaviest part of the computation. In general, this implementation is bound by either + the cost of statistics computation on workers or by communicating the sufficient statistics. 
+ */ + strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), @@ -79,9 +121,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { + def run(input: RDD[LabeledPoint]): RandomForestModel = { val timer = new TimeTracker() @@ -212,8 +254,7 @@ private class RandomForest ( } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - val treeWeights = Array.fill[Double](numTrees)(1.0) - new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average) + new RandomForestModel(strategy.algo, trees) } } @@ -231,21 +272,20 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Classification, s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -262,8 +302,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". * @param maxDepth Maximum depth of the tree. @@ -272,7 +311,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. 
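The "auto" behavior documented above can be written out as a pure function. A sketch, not patch code; the regression branch corresponds to the trainRegressor docs further down, which resolve "auto" to "onethird" for forests:

def resolveFeatureSubsetStrategy(requested: String, numTrees: Int, isClassification: Boolean): String =
  requested match {
    case "auto" if numTrees == 1    => "all"
    case "auto" if isClassification => "sqrt"      // forest classifier
    case "auto"                     => "onethird"  // forest regressor (see trainRegressor below)
    case other                      => other       // "all", "sqrt", "log2", "onethird" pass through
  }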
- * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -283,7 +322,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { + seed: Int = Utils.random.nextInt()): RandomForestModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) @@ -302,7 +341,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { trainClassifier(input.rdd, numClassesForClassification, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -319,21 +358,20 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Regression, s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -349,8 +387,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. * Supported values: "variance". * @param maxDepth Maximum depth of the tree. @@ -359,7 +396,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. 
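Putting the convenience overload above together, a usage sketch with the argument order confirmed by the Java wrapper in this hunk. Every hyperparameter value below is illustrative, not a recommendation from this patch:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD

// numClasses = 2, no categorical features, 50 trees, "auto" subset strategy,
// "gini" impurity, maxDepth = 4, maxBins = 100, fixed seed.
def trainForest(input: RDD[LabeledPoint]): RandomForestModel =
  RandomForest.trainClassifier(input, 2, Map.empty[Int, Int], 50, "auto", "gini", 4, 100, 12345)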
- * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], @@ -369,7 +406,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { + seed: Int = Utils.random.nextInt()): RandomForestModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) @@ -387,7 +424,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { trainRegressor(input.rdd, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -479,5 +516,4 @@ object RandomForest extends Serializable with Logging { 3 * totalBins } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index abbda040bd528..e703adbdbfbb3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -25,57 +25,39 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** * :: Experimental :: - * Stores all the configuration options for the boosting algorithms - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. + * + * @param treeStrategy Parameters for the tree algorithm. We support regression and binary + * classification for boosting. Impurity setting will be ignored. + * @param loss Loss function used for minimization during gradient boosting. * @param numIterations Number of iterations of boosting. In other words, the number of * weak hypotheses used in the final model. - * @param loss Loss function used for minimization during gradient boosting. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * This setting overrides any setting in [[weakLearnerParams]]. - * Default value is 2 (binary classification). - * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are - * supported. */ @Experimental case class BoostingStrategy( // Required boosting parameters - @BeanProperty var algo: Algo, - @BeanProperty var numIterations: Int, + @BeanProperty var treeStrategy: Strategy, @BeanProperty var loss: Loss, // Optional boosting parameters - @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.numClassesForClassification = numClassesForClassification - - /** - * Sets Algorithm using a String. 
- */ - def setAlgo(algo: String): Unit = algo match { - case "Classification" => setAlgo(Classification) - case "Regression" => setAlgo(Regression) - } + @BeanProperty var numIterations: Int = 100, + @BeanProperty var learningRate: Double = 0.1) extends Serializable { /** * Check validity of parameters. * Throws exception if invalid. */ private[tree] def assertValid(): Unit = { - algo match { + treeStrategy.algo match { case Classification => - require(numClassesForClassification == 2) + require(treeStrategy.numClassesForClassification == 2, + "Only binary classification is supported for boosting.") case Regression => // nothing case _ => throw new IllegalArgumentException( - s"BoostingStrategy given invalid algo parameter: $algo." + + s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." + s" Valid settings are: Classification, Regression.") } require(learningRate > 0 && learningRate <= 1, @@ -94,14 +76,14 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy("Regression") + val treeStrategy = Strategy.defaultStrategy(algo) treeStrategy.maxDepth = 3 algo match { case "Classification" => - new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + treeStrategy.numClassesForClassification = 2 + new BoostingStrategy(treeStrategy, LogLoss) case "Regression" => - new BoostingStrategy(Algo.withName(algo), 100, SquaredError, - weakLearnerParams = treeStrategy) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala index 82889dc00cdad..b5bf732d1b33a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala @@ -17,14 +17,10 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.DeveloperApi - /** - * :: Experimental :: * Enum to select ensemble combining strategy for base learners */ -@DeveloperApi -object EnsembleCombiningStrategy extends Enumeration { +private[tree] object EnsembleCombiningStrategy extends Enumeration { type EnsembleCombiningStrategy = Value - val Sum, Average = Value + val Average, Sum, Vote = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b5b1f82177edc..d75f38433c081 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -157,6 +157,13 @@ class Strategy ( require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } + + /** Returns a shallow copy of this instance. 
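The constructor change above, side by side: tree-level settings that used to be split across algo, numClassesForClassification, and weakLearnerParams now live in a single Strategy object. A sketch using only helpers that appear in this file (Strategy.defaultStrategy, LogLoss):

import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.loss.LogLoss

// Before: new BoostingStrategy(algo, numIterations, loss, learningRate,
//                              numClassesForClassification, weakLearnerParams)
// After:
val treeStrategy = Strategy.defaultStrategy("Classification")
treeStrategy.maxDepth = 3
treeStrategy.numClassesForClassification = 2  // boosting supports binary classification only
val boosting = new BoostingStrategy(treeStrategy, LogLoss, numIterations = 50, learningRate = 0.2)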
*/ + def copy: Strategy = { + new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + } } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index d111ffe30ed9e..d1bde15e6b150 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -17,19 +17,18 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: - * Class for least absolute error loss calculation. - * The features x and the corresponding label y is predicted using the function F. - * For each instance: - * Loss: |y - F| - * Negative gradient: sign(y - F) + * Class for absolute error loss calculation (for regression). + * + * The absolute (L1) error is defined as: + * |y - F(x)| + * where y is the label and F(x) is the model prediction for features x. */ @DeveloperApi object AbsoluteError extends Loss { @@ -37,30 +36,29 @@ object AbsoluteError extends Loss { /** * Method to calculate the gradients for the gradient boosting calculation for least * absolute error calculation. - * @param model Model of the weak learner + * The gradient with respect to F(x) is: sign(F(x) - y) + * @param model Ensemble model * @param point Instance of the training dataset * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 } /** - * Method to calculate error of the base learner for the gradient boosting calculation. + * Method to calculate loss of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param model Model of the weak learner. + * @param model Ensemble model * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
- * @return + * @return Mean absolute error of model on data */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { - val sumOfAbsolutes = data.map { y => + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { y => val err = model.predict(y.features) - y.label math.abs(err) - }.sum() - sumOfAbsolutes / data.count() + }.mean() } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 6f3d4340f0d3b..7ce9fa6f86c42 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -19,17 +19,17 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: - * Class for least squares error loss calculation. + * Class for log loss calculation (for classification). + * This uses twice the binomial negative log likelihood, called "deviance" in Friedman (1999). * - * The features x and the corresponding label y is predicted using the function F. - * For each instance: - * Loss: log(1 + exp(-2yF)), y in {-1, 1} - * Negative gradient: 2y / ( 1 + exp(2yF)) + * The log loss is defined as: + * 2 log(1 + exp(-2 y F(x))) + * where y is a label in {-1, 1} and F(x) is the model prediction for features x. */ @DeveloperApi object LogLoss extends Loss { @@ -37,27 +37,37 @@ object LogLoss extends Loss { /** * Method to calculate the loss gradients for the gradient boosting calculation for binary * classification - * @param model Model of the weak learner + * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x))) + * @param model Ensemble model * @param point Instance of the training dataset * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { val prediction = model.predict(point.features) - 1.0 / (1.0 + math.exp(-prediction)) - point.label + - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) } /** - * Method to calculate error of the base learner for the gradient boosting calculation. + * Method to calculate loss of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param model Model of the weak learner. + * @param model Ensemble model * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return + * @return Mean log loss of model on data */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { - val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count() - wrongPredictions / data.count + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { case point => + val prediction = model.predict(point.features) + val margin = 2.0 * point.label * prediction + // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically + // stable. 
+ if (margin >= 0) { + 2.0 * math.log1p(math.exp(-margin)) + } else { + 2.0 * (-margin + math.log1p(math.exp(margin))) + } + }.mean() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 5580866c879e2..4bca9039ebe1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** @@ -36,7 +36,7 @@ trait Loss extends Serializable { * @return Loss gradient. */ def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double /** @@ -47,6 +47,6 @@ trait Loss extends Serializable { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return */ - def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double + def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 4349fefef2c74..50ecaa2f86f35 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -17,20 +17,18 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: - * Class for least squares error loss calculation. + * Class for squared error loss calculation. * - * The features x and the corresponding label y is predicted using the function F. - * For each instance: - * Loss: (y - F)**2/2 - * Negative gradient: y - F + * The squared (L2) error is defined as: + * (y - F(x))**2 + * where y is the label and F(x) is the model prediction for features x. */ @DeveloperApi object SquaredError extends Loss { @@ -38,29 +36,29 @@ object SquaredError extends Loss { /** * Method to calculate the gradients for the gradient boosting calculation for least * squares error calculation. - * @param model Model of the weak learner + * The gradient with respect to F(x) is: - 2 (y - F(x)) + * @param model Ensemble model * @param point Instance of the training dataset * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { - model.predict(point.features) - point.label + 2.0 * (model.predict(point.features) - point.label) } /** - * Method to calculate error of the base learner for the gradient boosting calculation. + * Method to calculate loss of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param model Model of the weak learner. + * @param model Ensemble model * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
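The two branches in the LogLoss error above are algebraically identical; the split exists only to avoid floating-point overflow. A quick standalone check (plain Scala, no Spark) of why the naive form fails for large negative margins:

// log(1 + e^{-m}) overflows naively when e^{-m} exceeds Double.MaxValue;
// the shifted form -m + log1p(e^{m}) keeps the exponent non-positive.
def stableLogLoss(margin: Double): Double =
  if (margin >= 0) 2.0 * math.log1p(math.exp(-margin))
  else 2.0 * (-margin + math.log1p(math.exp(margin)))

stableLogLoss(-800.0)              // ≈ 1600.0, finite
2.0 * math.log1p(math.exp(800.0))  // Double.PositiveInfinity with the naive form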
- * @return + * @return Mean squared error of model on data */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map { y => val err = model.predict(y.features) - y.label err * err }.mean() } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ec1d99ab26f9c..a5760963068c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -18,9 +18,10 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: @@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } + + /** + * Predict values for the given data set using the model trained. + * + * @param features JavaRDD representing data points to be predicted + * @return JavaRDD of predictions for each of the given data points + */ + def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { + predict(features.rdd) + } + /** * Get number of nodes in tree, including leaf nodes. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala deleted file mode 100644 index 7b052d9163a13..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ -import org.apache.spark.rdd.RDD - -import scala.collection.mutable - -@Experimental -class WeightedEnsembleModel( - val weakHypotheses: Array[DecisionTreeModel], - val weakHypothesisWeights: Array[Double], - val algo: Algo, - val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { - - require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" + - s". 
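A reviewer can sanity-check the doubled SquaredError gradient above (the old code returned F(x) - y; the new code returns 2 (F(x) - y), which matches d/dF (y - F)^2 exactly) with a finite difference at an arbitrary point:

// d/dF (y - F)^2 = -2 (y - F) = 2 (F - y); check numerically at y = 3, F = 1.
val (y, f, eps) = (3.0, 1.0, 1e-6)
val numeric = (math.pow(y - (f + eps), 2) - math.pow(y - (f - eps), 2)) / (2 * eps)
val analytic = 2.0 * (f - y)
assert(math.abs(numeric - analytic) < 1e-4)  // both ≈ -4.0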
Number of weakHypotheses = $weakHypotheses") - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictRaw(features: Vector): Double = { - val treePredictions = weakHypotheses.map(learner => learner.predict(features)) - if (numWeakHypotheses == 1){ - treePredictions(0) - } else { - var prediction = treePredictions(0) - var index = 1 - while (index < numWeakHypotheses) { - prediction += weakHypothesisWeights(index) * treePredictions(index) - index += 1 - } - prediction - } - } - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictBySumming(features: Vector): Double = { - algo match { - case Regression => predictRaw(features) - case Classification => { - // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. - if (predictRaw(features) > 0 ) 1.0 else 0.0 - } - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Predict values for a single data point. - * - * @param features array representing a single data point - * @return Double prediction from the trained model - */ - private def predictByAveraging(features: Vector): Double = { - algo match { - case Classification => - val predictionToCount = new mutable.HashMap[Int, Int]() - weakHypotheses.foreach { learner => - val prediction = learner.predict(features).toInt - predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 - } - predictionToCount.maxBy(_._2)._1 - case Regression => - weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size - } - } - - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - def predict(features: Vector): Double = { - combiningStrategy match { - case Sum => predictBySumming(features) - case Average => predictByAveraging(features) - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.") - } - } - - /** - * Predict values for the given data set. - * - * @param features RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) - - /** - * Print a summary of the model. - */ - override def toString: String = { - algo match { - case Classification => - s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n" - case Regression => - s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n" - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Print the full model to a string. - */ - def toDebugString: String = { - val header = toString + "\n" - header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) => - s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) - }.fold("")(_ + _) - } - - /** - * Get number of trees in forest. - */ - def numWeakHypotheses: Int = weakHypotheses.size - - // TODO: Remove these helpers methods once class is generalized to support any base learning - // algorithms. 
- - /** - * Get total number of nodes, summed over all trees in the forest. - */ - def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala new file mode 100644 index 0000000000000..22997110de8dd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Represents a random forest model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + */ +@Experimental +class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) + extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), + combiningStrategy = if (algo == Classification) Vote else Average) { + + require(trees.forall(_.algo == algo)) +} + +/** + * :: Experimental :: + * Represents a gradient boosted trees model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + */ +@Experimental +class GradientBoostedTreesModel( + override val algo: Algo, + override val trees: Array[DecisionTreeModel], + override val treeWeights: Array[Double]) + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { + + require(trees.size == treeWeights.size) +} + +/** + * Represents a tree ensemble model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + * @param combiningStrategy strategy for combining the predictions, not used for regression. 
+ */ +private[tree] sealed class TreeEnsembleModel( + protected val algo: Algo, + protected val trees: Array[DecisionTreeModel], + protected val treeWeights: Array[Double], + protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { + + require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.") + + private val sumWeights = math.max(treeWeights.sum, 1e-15) + + /** + * Predicts for a single data point using the weighted sum of ensemble predictions. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictBySumming(features: Vector): Double = { + val treePredictions = trees.map(_.predict(features)) + blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) + } + + /** + * Classifies a single data point based on (weighted) majority votes. + */ + private def predictByVoting(features: Vector): Double = { + val votes = mutable.Map.empty[Int, Double] + trees.view.zip(treeWeights).foreach { case (tree, weight) => + val prediction = tree.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + weight + } + votes.maxBy(_._2)._1 + } + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + def predict(features: Vector): Double = { + (algo, combiningStrategy) match { + case (Regression, Sum) => + predictBySumming(features) + case (Regression, Average) => + predictBySumming(features) / sumWeights + case (Classification, Sum) => // binary classification + val prediction = predictBySumming(features) + // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. + if (prediction > 0.0) 1.0 else 0.0 + case (Classification, Vote) => + predictByVoting(features) + case _ => + throw new IllegalArgumentException( + "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " + + s"($algo, $combiningStrategy).") + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) + + /** + * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]]. + */ + def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { + predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] + } + + /** + * Print a summary of the model. + */ + override def toString: String = { + algo match { + case Classification => + s"TreeEnsembleModel classifier with $numTrees trees\n" + case Regression => + s"TreeEnsembleModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"TreeEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Get total number of nodes, summed over all trees in the forest. 
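The (algo, combiningStrategy) dispatch above reduces to three arithmetic rules. On plain arrays, as a hypothetical standalone sketch rather than patch code:

def combine(treePredictions: Array[Double], treeWeights: Array[Double], rule: String): Double =
  rule match {
    case "Sum" =>      // GBT: weighted sum (sign-thresholded afterwards for classification)
      treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum
    case "Average" =>  // random forest regression: weighted mean
      treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum / treeWeights.sum
    case "Vote" =>     // random forest classification: weighted majority vote
      treePredictions.zip(treeWeights)
        .groupBy { case (p, _) => p }
        .mapValues(_.map(_._2).sum)
        .maxBy(_._2)._1
  }

The patch's predictBySumming computes the "Sum" case with a single blas.ddot instead of the zip/map/sum shown here; the arithmetic is the same.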
+ */ + def totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java new file mode 100644 index 0000000000000..42846677ed285 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +/** + * Test Pipeline construction and fitting in Java. + */ +public class JavaPipelineSuite { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPipelineSuite"); + jsql = new JavaSQLContext(jsc); + JavaRDD points = + jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); + dataset = jsql.applySchema(points, LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void pipeline() { + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + LogisticRegression lr = new LogisticRegression() + .setFeaturesCol("scaledFeatures"); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {scaler, lr}); + PipelineModel model = pipeline.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java new file mode 100644 index 0000000000000..76eb7f00329f2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaLogisticRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new JavaSQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void logisticRegression() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionWithSetters() { + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold + .registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionFitWithVarargs() { + LogisticRegression lr = new LogisticRegression(); + lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java new file mode 100644 index 0000000000000..a266ebd2071a1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaCrossValidatorSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); + jsql = new JavaSQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void crossValidationWithLogisticRegression() { + LogisticRegression lr = new LogisticRegression(); + ParamMap[] lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) + .addGrid(lr.maxIter(), new int[] {0, 10}) + .build(); + BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); + CrossValidator cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3); + CrossValidatorModel cvModel = cv.fit(dataset); + ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); + Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); + Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index f6ca9643227f8..af688c504cf1e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -23,13 +23,14 @@ import scala.Tuple2; import scala.Tuple3; +import com.google.common.collect.Lists; import org.jblas.DoubleMatrix; - import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -47,61 +48,48 @@ public void tearDown() { sc = null; } - static void validatePrediction( + void validatePrediction( MatrixFactorizationModel model, int users, int products, - int features, DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = 
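The grid in the cross-validator test above expands to the 2 x 2 product of candidate settings, each evaluated with 3-fold cross-validation. A Scala sketch of building the same grid, assuming the Array-typed addGrid overloads that the Java test exercises:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.tuning.ParamGridBuilder

val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.001, 1000.0))
  .addGrid(lr.maxIter, Array(0, 10))
  .build()  // the cross product of both grids
assert(lrParamMaps.length == 4)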
model.userFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2 userFeature : userFeatures) { - predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); - } - } - DoubleMatrix predictedP = new DoubleMatrix(products, features); - - List> productFeatures = - model.productFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2 productFeature : productFeatures) { - predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); + List> localUsersProducts = + Lists.newArrayListWithCapacity(users * products); + for (int u=0; u < users; ++u) { + for (int p=0; p < products; ++p) { + localUsersProducts.add(new Tuple2(u, p)); } } - - DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - + JavaPairRDD usersProducts = sc.parallelizePairs(localUsersProducts); + List predictedRatings = model.predict(usersProducts).collect(); + Assert.assertEquals(users * products, predictedRatings.size()); if (!implicitPrefs) { - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", - prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double correct = trueRatings.get(r.user(), r.product()); + Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", + prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); } } else { // For implicit prefs we use the confidence-weighted RMSE to test // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double truePref = truePrefs.get(u, p); - double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p)); - double err = confidence * (truePref - prediction) * (truePref - prediction); - sqErr += err; - denom += confidence; - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double truePref = truePrefs.get(r.user(), r.product()); + double confidence = 1.0 + + /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product())); + double err = confidence * (truePref - prediction) * (truePref - prediction); + sqErr += err; + denom += confidence; } double rmse = Math.sqrt(sqErr / denom); Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", - rmse, matchThreshold), rmse < matchThreshold); + rmse, matchThreshold), rmse < matchThreshold); } } @@ -116,7 +104,7 @@ public void runALSUsingStaticMethods() { JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -132,8 +120,8 @@ public void runALSUsingConstructor() { MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) - .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + .run(data); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -147,7 +135,7 @@ public void 
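The implicit-feedback branch above scores with a confidence-weighted RMSE (the Mahout-style metric the comment cites): rmse = sqrt( sum_i c_i (t_i - p_i)^2 / sum_i c_i ) with c_i = 1 + alpha * |r_i|. Extracted as a pure function for clarity (a sketch with hypothetical names):

def confidenceWeightedRmse(
    truePrefs: Seq[Double],
    predictions: Seq[Double],
    ratings: Seq[Double],
    alpha: Double = 1.0): Double = {
  val terms = (truePrefs, predictions, ratings).zipped.map { (t, p, r) =>
    val c = 1.0 + alpha * math.abs(r)
    (c * (t - p) * (t - p), c)
  }
  math.sqrt(terms.map(_._1).sum / terms.map(_._2).sum)
}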
runImplicitALSUsingStaticMethods() { JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -165,7 +153,7 @@ public void runImplicitALSUsingConstructor() { .setIterations(iterations) .setImplicitPrefs(true) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -183,7 +171,7 @@ public void runImplicitALSWithNegativeWeight() { .setImplicitPrefs(true) .setSeed(8675309L) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 2c281a1ee7157..9925aae441af9 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -74,7 +74,7 @@ public void runDTUsingConstructor() { maxBins, categoricalFeaturesInfo); DecisionTree learner = new DecisionTree(strategy); - DecisionTreeModel model = learner.train(rdd.rdd()); + DecisionTreeModel model = learner.run(rdd.rdd()); int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala new file mode 100644 index 0000000000000..4515084bc7ae9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.when +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +class PipelineSuite extends FunSuite { + + abstract class MyModel extends Model[MyModel] + + test("pipeline") { + val estimator0 = mock[Estimator[MyModel]] + val model0 = mock[MyModel] + val transformer1 = mock[Transformer] + val estimator2 = mock[Estimator[MyModel]] + val model2 = mock[MyModel] + val transformer3 = mock[Transformer] + val dataset0 = mock[SchemaRDD] + val dataset1 = mock[SchemaRDD] + val dataset2 = mock[SchemaRDD] + val dataset3 = mock[SchemaRDD] + val dataset4 = mock[SchemaRDD] + + when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) + when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(model0.parent).thenReturn(estimator0) + when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) + when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(model2.parent).thenReturn(estimator2) + when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + + val pipeline = new Pipeline() + .setStages(Array(estimator0, transformer1, estimator2, transformer3)) + val pipelineModel = pipeline.fit(dataset0) + + assert(pipelineModel.stages.size === 4) + assert(pipelineModel.stages(0).eq(model0)) + assert(pipelineModel.stages(1).eq(transformer1)) + assert(pipelineModel.stages(2).eq(model2)) + assert(pipelineModel.stages(3).eq(transformer3)) + + assert(pipelineModel.getModel(estimator0).eq(model0)) + assert(pipelineModel.getModel(estimator2).eq(model2)) + intercept[NoSuchElementException] { + pipelineModel.getModel(mock[Estimator[MyModel]]) + } + val output = pipelineModel.transform(dataset0) + assert(output.eq(dataset4)) + } + + test("pipeline with duplicate stages") { + val estimator = mock[Estimator[MyModel]] + val pipeline = new Pipeline() + .setStages(Array(estimator, estimator)) + val dataset = mock[SchemaRDD] + intercept[IllegalArgumentException] { + pipeline.fit(dataset) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala new file mode 100644 index 0000000000000..e8030fef55b1d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
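The PipelineSuite above pins down the fitting contract with mocks: each Estimator stage is fit on the current dataset and replaced by the Model it returns, each plain Transformer is carried through unchanged, and the dataset is transformed before moving to the next stage. A toy sketch of that fold (toy Stage/Estimator/Transformer types, not the spark.ml classes):

```scala
object PipelineFitSketch {
  sealed trait Stage
  trait Transformer extends Stage { def transform(data: Vector[Double]): Vector[Double] }
  trait Estimator extends Stage { def fit(data: Vector[Double]): Transformer }

  // Walk the stages left to right, fitting estimators and threading the transformed data.
  def fit(stages: Seq[Stage], data: Vector[Double]): (Seq[Transformer], Vector[Double]) =
    stages.foldLeft((Vector.empty[Transformer], data)) { case ((fitted, cur), stage) =>
      val transformer = stage match {
        case e: Estimator   => e.fit(cur) // estimator is replaced by the model it produces
        case t: Transformer => t          // transformer passes through unchanged
      }
      (fitted :+ transformer, transformer.transform(cur))
    }

  def main(args: Array[String]): Unit = {
    val scale = new Estimator {
      def fit(d: Vector[Double]) = new Transformer {
        def transform(x: Vector[Double]) = x.map(_ / d.max)
      }
    }
    val shift = new Transformer { def transform(x: Vector[Double]) = x.map(_ + 1.0) }
    println(fit(Seq(scale, shift), Vector(2.0, 4.0))._2) // Vector(1.5, 2.0)
  }
}
```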
+ */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{SQLContext, SchemaRDD} + +class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + @transient var dataset: SchemaRDD = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataset = sqlContext.createSchemaRDD( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + } + + test("logistic regression") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression + val model = lr.fit(dataset) + model.transform(dataset) + .select('label, 'prediction) + .collect() + } + + test("logistic regression with setters") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + model.transform(dataset, model.threshold -> 0.8) // overwrite threshold + .select('label, 'score, 'prediction) + .collect() + } + + test("logistic regression fit and transform with varargs") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression + val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) + model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") + .select('label, 'probability, 'prediction) + .collect() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala new file mode 100644 index 0000000000000..1ce2987612378 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
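The setter test above overrides model.threshold at transform time (model.transform(dataset, model.threshold -> 0.8)). A standalone reminder, assuming the usual logistic score and a 0.5 default threshold, of what that override changes: the score column keeps the raw probability while the prediction column is the score cut at the threshold.

```scala
object ThresholdSketch {
  // score = sigmoid(margin); prediction = 1.0 iff the score clears the threshold.
  def score(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))
  def predict(margin: Double, threshold: Double): Double =
    if (score(margin) > threshold) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    val margin = 1.2
    println(score(margin))        // ~0.77
    println(predict(margin, 0.5)) // 1.0 under the assumed 0.5 default
    println(predict(margin, 0.8)) // 0.0 once the threshold is raised to 0.8
  }
}
```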
+ */ + +package org.apache.spark.ml.param + +import org.scalatest.FunSuite + +class ParamsSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param") { + assert(maxIter.name === "maxIter") + assert(maxIter.doc === "max number of iterations") + assert(maxIter.defaultValue.get === 100) + assert(maxIter.parent.eq(solver)) + assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") + assert(inputCol.defaultValue === None) + } + + test("param pair") { + val pair0 = maxIter -> 5 + val pair1 = maxIter.w(5) + val pair2 = ParamPair(maxIter, 5) + for (pair <- Seq(pair0, pair1, pair2)) { + assert(pair.param.eq(maxIter)) + assert(pair.value === 5) + } + } + + test("param map") { + val map0 = ParamMap.empty + + assert(!map0.contains(maxIter)) + assert(map0(maxIter) === maxIter.defaultValue.get) + map0.put(maxIter, 10) + assert(map0.contains(maxIter)) + assert(map0(maxIter) === 10) + + assert(!map0.contains(inputCol)) + intercept[NoSuchElementException] { + map0(inputCol) + } + map0.put(inputCol -> "input") + assert(map0.contains(inputCol)) + assert(map0(inputCol) === "input") + + val map1 = map0.copy + val map2 = ParamMap(maxIter -> 10, inputCol -> "input") + val map3 = new ParamMap() + .put(maxIter, 10) + .put(inputCol, "input") + val map4 = ParamMap.empty ++ map0 + val map5 = ParamMap.empty + map5 ++= map0 + + for (m <- Seq(map1, map2, map3, map4, map5)) { + assert(m.contains(maxIter)) + assert(m(maxIter) === 10) + assert(m.contains(inputCol)) + assert(m(inputCol) === "input") + } + } + + test("params") { + val params = solver.params + assert(params.size === 2) + assert(params(0).eq(inputCol), "params must be ordered by name") + assert(params(1).eq(maxIter)) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) + assert(solver.getParam("maxIter").eq(maxIter)) + intercept[NoSuchMethodException] { + solver.getParam("abc") + } + assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { + solver.validate() + } + solver.validate(ParamMap(inputCol -> "input")) + solver.setInputCol("input") + assert(solver.isSet(inputCol)) + assert(solver.getInputCol === "input") + solver.validate() + intercept[IllegalArgumentException] { + solver.validate(ParamMap(maxIter -> -10)) + } + solver.setMaxIter(-10) + intercept[IllegalArgumentException] { + solver.validate() + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala new file mode 100644 index 0000000000000..1a65883d78a71 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
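ParamsSuite above fixes the lookup rule for ParamMap: an explicitly set value wins, otherwise the param's default applies, and a param with neither raises NoSuchElementException. A toy model of just that rule (toy Param/ParamMap types, not the spark.ml.param classes):

```scala
object ParamMapSketch {
  final case class Param[T](name: String, default: Option[T])

  final class ParamMap(values: Map[String, Any] = Map.empty) {
    def put[T](param: Param[T], value: T): ParamMap = new ParamMap(values + (param.name -> value))
    def contains(param: Param[_]): Boolean = values.contains(param.name)
    def apply[T](param: Param[T]): T = values.get(param.name).orElse(param.default) match {
      case Some(v) => v.asInstanceOf[T] // explicit value, else the default
      case None    => throw new NoSuchElementException(param.name + " is not specified")
    }
  }

  def main(args: Array[String]): Unit = {
    val maxIter  = Param("maxIter", Some(100))
    val inputCol = Param[String]("inputCol", None)
    val empty = new ParamMap()
    val m = empty.put(maxIter, 10)
    println(m(maxIter))           // 10: the explicit value wins
    println(empty(maxIter))       // 100: falls back to the default
    println(m.contains(inputCol)) // false; m(inputCol) would throw
  }
}
```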
+ */ + +package org.apache.spark.ml.param + +/** A subclass of Params for testing. */ +class TestParams extends Params { + + val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + def setMaxIter(value: Int): this.type = { set(maxIter, value); this } + def getMaxIter: Int = get(maxIter) + + val inputCol = new Param[String](this, "inputCol", "input column name") + def setInputCol(value: String): this.type = { set(inputCol, value); this } + def getInputCol: String = get(inputCol) + + override def validate(paramMap: ParamMap) = { + val m = this.paramMap ++ paramMap + require(m(maxIter) >= 0) + require(m.contains(inputCol)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala new file mode 100644 index 0000000000000..41cc13da4d5b1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.scalatest.FunSuite + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{SQLContext, SchemaRDD} + +class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { + + @transient var dataset: SchemaRDD = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sqlContext = new SQLContext(sc) + dataset = sqlContext.createSchemaRDD( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + } + + test("cross validation with logistic regression") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val bestParamMap = cvModel.bestModel.fittingParamMap + assert(bestParamMap(lr.regParam) === 0.001) + assert(bestParamMap(lr.maxIter) === 10) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala new file mode 100644 index 0000000000000..20aa100112bfe --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
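CrossValidatorSuite above builds its search space with ParamGridBuilder; as the ParamGridBuilderSuite below also checks, the resulting grid is simply the cross product of the values added per param. A plain-Scala sketch of that expansion (not the builder itself), using the same regParam/maxIter values as the test:

```scala
object ParamGridSketch {
  // Expand a (paramName -> candidate values) list into every combination.
  def crossProduct(grid: Seq[(String, Seq[Any])]): Seq[Map[String, Any]] =
    grid.foldLeft(Seq(Map.empty[String, Any])) { case (maps, (name, values)) =>
      for (m <- maps; v <- values) yield m + (name -> v)
    }

  def main(args: Array[String]): Unit = {
    val grid = crossProduct(Seq("regParam" -> Seq(0.001, 1000.0), "maxIter" -> Seq(0, 10)))
    grid.foreach(println)  // 4 combinations
    println(grid.size * 3) // 12 fits when each combination is evaluated over numFolds = 3
  }
}
```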
+ */ + +package org.apache.spark.ml.tuning + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.ml.param.{ParamMap, TestParams} + +class ParamGridBuilderSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param grid builder") { + def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = { + assert(maps.size === expected.size) + maps.foreach { m => + val tuple = (m(maxIter), m(inputCol)) + assert(expected.contains(tuple)) + expected.remove(tuple) + } + assert(expected.isEmpty) + } + + val maps0 = new ParamGridBuilder() + .baseOn(maxIter -> 10) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected0 = mutable.Set( + (10, "input0"), + (10, "input1")) + validateGrid(maps0, expected0) + + val maps1 = new ParamGridBuilder() + .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten + .addGrid(maxIter, Array(10, 20)) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected1 = mutable.Set( + (10, "input0"), + (20, "input0"), + (10, "input1"), + (20, "input1")) + validateGrid(maps1, expected1) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index e954baaf7d91e..4e812994405b3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object LogisticRegressionSuite { @@ -57,7 +57,7 @@ object LogisticRegressionSuite { } } -class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { +class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -80,13 +80,16 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val testRDD = sc.parallelize(testData, 2) testRDD.cache() val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(20) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(20) val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -112,10 +115,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) - assert(model.weights(0) ~== model.weights(0) relTol 0.01) - assert(model.intercept ~== model.intercept relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, 
B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -141,13 +142,16 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Use half as many iterations as the previous test. val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(10) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(10) val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.01) - assert(model.intercept ~== 1.97 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -212,8 +216,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.02) - assert(model.intercept ~== 1.97 relTol 0.02) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 80989bc074e84..e68fe89d6ccea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} object NaiveBayesSuite { @@ -60,7 +60,7 @@ object NaiveBayesSuite { } } -class NaiveBayesSuite extends FunSuite with LocalSparkContext { +class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 65e5df58db4c7..a2de7fbd41383 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} object SVMSuite { @@ -58,7 +58,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with LocalSparkContext { +class SVMSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index afa1f79b95a12..9ebef8466c831 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -22,10 +22,10 @@ import scala.util.Random import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class KMeansSuite extends FunSuite with LocalSparkContext { +class KMeansSuite extends FunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 994e0feb8629e..79847633ff0dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { +class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index a733f88b60b80..8a18e2971cab6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -19,44 +19,109 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { +class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { - def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 + private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - def cond2(x: ((Double, Double), (Double, Double))): Boolean = + private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean = (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) + private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = { + assert(left.zip(right).forall(areWithinEpsilon)) + } + + private def assertTupleSequencesMatch(left: Seq[(Double, Double)], + right: Seq[(Double, Double)]): Unit = { + assert(left.zip(right).forall(pairsWithinEpsilon)) + } + + private def validateMetrics(metrics: BinaryClassificationMetrics, + expectedThresholds: Seq[Double], + expectedROCCurve: Seq[(Double, Double)], + expectedPRCurve: Seq[(Double, Double)], + expectedFMeasures1: Seq[Double], + expectedFmeasures2: Seq[Double], + expectedPrecisions: 
Seq[Double], + expectedRecalls: Seq[Double]) = { + + assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds) + assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve) + assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5) + assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve) + assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5) + assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), + expectedThresholds.zip(expectedFMeasures1)) + assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), + expectedThresholds.zip(expectedFmeasures2)) + assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), + expectedThresholds.zip(expectedPrecisions)) + assertTupleSequencesMatch(metrics.recallByThreshold().collect(), + expectedThresholds.zip(expectedRecalls)) + } + test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) val metrics = new BinaryClassificationMetrics(scoreAndLabels) - val threshold = Seq(0.8, 0.6, 0.4, 0.1) + val thresholds = Seq(0.8, 0.6, 0.4, 0.1) val numTruePositives = Seq(1, 3, 3, 4) val numFalsePositives = Seq(0, 1, 2, 3) val numPositives = 4 val numNegatives = 3 - val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) => + val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) => t.toDouble / (t + f) } - val recall = numTruePositives.map(t => t.toDouble / numPositives) + val recalls = numTruePositives.map(t => t.toDouble / numPositives) val fpr = numFalsePositives.map(f => f.toDouble / numNegatives) - val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) - val pr = recall.zip(precision) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) val prCurve = Seq((0.0, 1.0)) ++ pr val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(metrics.thresholds().collect().zip(threshold).forall(cond1)) - assert(metrics.roc().collect().zip(rocCurve).forall(cond2)) - assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5) - assert(metrics.pr().collect().zip(prCurve).forall(cond2)) - assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5) - assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2)) - assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2)) - assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2)) - assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2)) + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) + } + + test("binary evaluation metrics for RDD where all examples have positive label") { + val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + + val thresholds = Seq(0.5) + val precisions = Seq(1.0) + val recalls = Seq(1.0) + val fpr = Seq(0.0) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} + val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, 
precisions, recalls) + } + + test("binary evaluation metrics for RDD where all examples have negative label") { + val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + + val thresholds = Seq(0.5) + val precisions = Seq(0.0) + val recalls = Seq(0.0) + val fpr = Seq(1.0) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { + case (0, 0) => 0.0 + case (r, p) => 2.0 * (p * r) / (p + r) + } + val f2 = pr.map { + case (0, 0) => 0.0 + case (r, p) => 5.0 * (p * r) / (4.0 * p + r) + } + + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 1ea503971c864..7dc4f3cfbc4e4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Matrices -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { +class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 342baa0274e9c..2537dd62c92f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with LocalSparkContext { +class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index a2d4bb41484b8..609eed983ff4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with LocalSparkContext { +class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 5396d7b2b74fa..670b4c34e6095 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with LocalSparkContext { +class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index a599e0d938569..0c4dfb7b97c7f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with LocalSparkContext { +class HashingTFSuite extends FunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 43974f84e3ca8..30147e7fd948f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.FunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with LocalSparkContext { +class IDFSuite extends FunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 2bf9d9816ae45..85fdd271b5ed1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -22,10 +22,10 @@ import org.scalatest.FunSuite import breeze.linalg.{norm => brzNorm} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with LocalSparkContext { +class NormalizerSuite extends FunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index e217b93cebbdb..4c93c0ca4f86c 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with LocalSparkContext { +class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { data.treeAggregate(new MultivariateOnlineSummarizer)( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index e34335d89eb75..52278690dbd89 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class Word2VecSuite extends FunSuite with LocalSparkContext { +class Word2VecSuite extends FunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 5f8b8c4b72697..322a0e9242918 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -17,7 +17,11 @@ package org.apache.spark.mllib.linalg +import java.util.Random + +import org.mockito.Mockito.when import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -112,4 +116,50 @@ class MatricesSuite extends FunSuite { assert(sparseMat(0, 1) === 10.0) assert(sparseMat.values(2) === 10.0) } + + test("zeros") { + val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 0.0)) + } + + test("ones") { + val mat = Matrices.ones(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 1.0)) + } + + test("eye") { + val mat = Matrices.eye(2).asInstanceOf[DenseMatrix] + assert(mat.numCols === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 1.0)) + } + + test("rand") { + val rng = mock[Random] + when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.rand(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("randn") { + val rng = mock[Random] + when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.randn(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("diag") { + val mat = 
Matrices.diag(Vectors.dense(1.0, 2.0)).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 2.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 93a84fe07b32a..9492f604af4d5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.linalg +import breeze.linalg.{DenseMatrix => BDM} import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -166,4 +167,34 @@ class VectorsSuite extends FunSuite { assert(v === udt.deserialize(udt.serialize(v))) } } + + test("fromBreeze") { + val x = BDM.zeros[Double](10, 10) + val v = Vectors.fromBreeze(x(::, 0)) + assert(v.size === x.rows) + } + + test("foreachActive") { + val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0) + val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0))) + + val dvMap = scala.collection.mutable.Map[Int, Double]() + dv.foreachActive { (index, value) => + dvMap.put(index, value) + } + assert(dvMap.size === 4) + assert(dvMap.get(0) === Some(0.0)) + assert(dvMap.get(1) === Some(1.2)) + assert(dvMap.get(2) === Some(3.1)) + assert(dvMap.get(3) === Some(0.0)) + + val svMap = scala.collection.mutable.Map[Int, Double]() + sv.foreachActive { (index, value) => + svMap.put(index, value) + } + assert(svMap.size === 3) + assert(svMap.get(1) === Some(1.2)) + assert(svMap.get(2) === Some(3.1)) + assert(svMap.get(3) === Some(0.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index cd45438fb628f..f8709751efce6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.FunSuite import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with LocalSparkContext { +class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index f7c46f23b746d..e25bc02b06c9a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -21,11 +21,11 @@ import org.scalatest.FunSuite import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with LocalSparkContext { +class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 63f3ed58c0d4d..dbf55ff81ca99 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -23,9 +23,9 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, s import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with LocalSparkContext { +class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index bf040110e228b..86481c6e66200 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -61,7 +61,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers { +class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index ccba004baa007..70c64775e4c04 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -23,10 +23,10 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { +class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index c50b78bcbcc61..ea5889b3ecd5e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import 
org.apache.spark.util.StatCounter @@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable { +class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 4ef67a40b9f49..681ce9263933b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.rdd import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with LocalSparkContext { +class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 017c39edb185f..603d0ad127b86 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.jblas.DoubleMatrix import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.recommendation.ALS.BlockStats object ALSSuite { @@ -85,7 +85,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with LocalSparkContext { +class ALSSuite extends FunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala new file mode 100644 index 0000000000000..b9caecc904a23 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.recommendation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { + + val rank = 2 + var userFeatures: RDD[(Int, Array[Double])] = _ + var prodFeatures: RDD[(Int, Array[Double])] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0)))) + prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0)))) + } + + test("constructor") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + assert(model.predict(0, 2) ~== 17.0 relTol 1e-14) + + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(1, userFeatures, prodFeatures) + } + + val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0)))) + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(rank, userFeatures1, prodFeatures) + } + + val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0)))) + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 7aa96421aed87..2668dcc14a842 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -23,9 +23,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class LassoSuite extends FunSuite with LocalSparkContext { +class LassoSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 4f89112b650c5..864622a9296a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -23,9 +23,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class LinearRegressionSuite extends FunSuite with LocalSparkContext { +class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 727bbd051ff15..18d3bf5ea4eca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -24,9 +24,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors 
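A quick check (not from the patch) of the value asserted in the MatrixFactorizationModelSuite constructor test above: predict(user, product) is the dot product of the rank-2 user and product feature vectors, so user 0 = [1.0, 2.0] against product 2 = [5.0, 6.0] gives 1*5 + 2*6 = 17.0.

```scala
object MatrixFactorizationPredictSketch {
  // Dot product of a user's and a product's latent features.
  def predict(userFeatures: Array[Double], productFeatures: Array[Double]): Double =
    userFeatures.zip(productFeatures).map { case (u, p) => u * p }.sum

  def main(args: Array[String]): Unit =
    println(predict(Array(1.0, 2.0), Array(5.0, 6.0))) // 17.0
}
```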
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class RidgeRegressionSuite extends FunSuite with LocalSparkContext { +class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { predictions.zip(input).map { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index 34548c86ebc14..d20a09b4b4925 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -24,9 +24,9 @@ import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with LocalSparkContext { +class CorrelationSuite extends FunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 6de3840b3f198..15418e6035965 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -25,10 +25,10 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with LocalSparkContext { +class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 1e9415249104b..23b0eec865de6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } + + test("merging summarizer when one side has zero mean (SPARK-4355)") { + val s0 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(2.0)) + .add(Vectors.dense(2.0)) + val s1 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(1.0)) + .add(Vectors.dense(-1.0)) + s0.merge(s1) + assert(s0.mean(0) ~== 1.0 absTol 1e-14) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index c579cb58549f5..972c905ec9ffa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -30,9 +30,9 @@ 
import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class DecisionTreeSuite extends FunSuite with LocalSparkContext { +class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index effb7b8259ffb..8972c229b7ecb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.util.StatCounter import scala.collection.mutable @@ -48,7 +48,7 @@ object EnsembleTestHelper { } def validateClassifier( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, input: Seq[LabeledPoint], requiredAccuracy: Double) { val predictions = input.map(x => model.predict(x.features)) @@ -60,17 +60,27 @@ object EnsembleTestHelper { s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } + /** + * Validates a tree ensemble model for regression. + */ def validateRegressor( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, input: Seq[LabeledPoint], - requiredMSE: Double) { + required: Double, + metricName: String = "mse") { val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => + prediction - label + } + val metric = metricName match { + case "mse" => + errors.map(err => err * err).sum / errors.size + case "mae" => + errors.map(math.abs).sum / errors.size + } + + assert(metric <= required, + s"validateRegressor calculated $metricName $metric but required $required.") } def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala new file mode 100644 index 0000000000000..d4d54cf4c9e2a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
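EnsembleTestHelper.validateRegressor above now takes a metric name so the new AbsoluteError loss can be validated with MAE instead of MSE. A standalone sketch of the two metrics as the helper computes them:

```scala
object RegressionMetricSketch {
  // "mse": mean of squared errors; "mae": mean of absolute errors.
  def metric(errors: Seq[Double], metricName: String = "mse"): Double = metricName match {
    case "mse" => errors.map(err => err * err).sum / errors.size
    case "mae" => errors.map(err => math.abs(err)).sum / errors.size
  }

  def main(args: Array[String]): Unit = {
    val errors = Seq(0.5, -1.0, 0.25)
    println(metric(errors))        // (0.25 + 1.0 + 0.0625) / 3 = 0.4375
    println(metric(errors, "mae")) // (0.5 + 1.0 + 0.25) / 3 ≈ 0.583
  }
}
```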
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[GradientBoostedTrees]]. + */ +class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { + + test("Regression with continuous features: SquaredError") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) + } + } + } + + test("Regression with continuous features: Absolute Error") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, AbsoluteError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae") + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. 
+ assert(gbt.trees.head.toString == dt.toString) + } + } + + test("Binary classification with continuous features: Log Loss") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty, + subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, LogLoss, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val ensembleStrategy = treeStrategy.copy + ensembleStrategy.algo = Regression + ensembleStrategy.impurity = Variance + val dt = DecisionTree.train(remappedInput, ensembleStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) + } + } + +} + +object GradientBoostedTreesSuite { + + // Combinations for estimators, learning rates and subsamplingRate + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + val randomSeeds = Array(681283, 4398) + + val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala deleted file mode 100644 index 99a02eda60baf..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree - -import org.scalatest.FunSuite - -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} -import org.apache.spark.mllib.tree.impurity.Variance -import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} - -import org.apache.spark.mllib.util.LocalSparkContext - -/** - * Test suite for [[GradientBoosting]]. 
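The deleted GradientBoostingSuite that continues below exercised the old GradientBoosting.trainRegressor / weakHypotheses API; the new suite above uses the renamed GradientBoostedTrees entry point and its trees field. A minimal training sketch against the new API, assuming an existing SparkContext sc and a Seq[LabeledPoint] named data (both hypothetical here):

  import org.apache.spark.mllib.tree.GradientBoostedTrees
  import org.apache.spark.mllib.tree.configuration.Algo.Regression
  import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
  import org.apache.spark.mllib.tree.impurity.Variance
  import org.apache.spark.mllib.tree.loss.SquaredError

  val rdd = sc.parallelize(data, 2)
  val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
    categoricalFeaturesInfo = Map.empty)
  val numIterations = 10
  val learningRate = 0.1
  val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
  val model = GradientBoostedTrees.train(rdd, boostingStrategy)
  assert(model.trees.size == numIterations)           // one weak learner per boosting iteration
  val firstPrediction = model.predict(data.head.features)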
- */ -class GradientBoostingSuite extends FunSuite with LocalSparkContext { - - test("Regression with continuous features: SquaredError") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, - learningRate, 1, treeStrategy) - - val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) - - // Make sure trees are the same. - assert(gbtTree.toString == dt.toString) - } - } - - test("Regression with continuous features: Absolute Error") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, - learningRate, numClassesForClassification = 2, treeStrategy) - - val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) - - // Make sure trees are the same. - assert(gbtTree.toString == dt.toString) - } - } - - test("Binary classification with continuous features: Log Loss") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, - learningRate, numClassesForClassification = 2, treeStrategy) - - val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) - - // Make sure trees are the same. 
- assert(gbtTree.toString == dt.toString) - } - } - -} - -object GradientBoostingSuite { - - // Combinations for estimators, learning rates and subsamplingRate - val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) - -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 73c4393c3581a..90a8c2dfdab80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -28,12 +28,12 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} import org.apache.spark.mllib.tree.model.Node -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[RandomForest]]. */ -class RandomForestSuite extends FunSuite with LocalSparkContext { +class RandomForestSuite extends FunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) @@ -41,8 +41,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.weakHypotheses.size === 1) - val rfTree = rf.weakHypotheses(0) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -65,7 +65,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { val categoricalFeaturesInfo = Map.empty[Int, Int] val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) binaryClassificationTestWithContinuousFeatures(strategy) } @@ -76,8 +77,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.weakHypotheses.size === 1) - val rfTree = rf.weakHypotheses(0) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -175,7 +176,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { test("Binary classification with continuous features and node Id cache: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index 5cb433232e714..b184e936672ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree.impl import org.scalatest.FunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[BaggedPoint]]. 
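In the RandomForestSuite changes above, the ensemble's forest is now exposed as trees rather than weakHypotheses. A minimal sketch of training a small forest with the updated accessor, assuming an existing SparkContext sc and a Seq[LabeledPoint] named points (both hypothetical):

  import org.apache.spark.mllib.tree.RandomForest
  import org.apache.spark.mllib.tree.configuration.Algo.Classification
  import org.apache.spark.mllib.tree.configuration.Strategy
  import org.apache.spark.mllib.tree.impurity.Gini

  val rdd = sc.parallelize(points)
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
    numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty[Int, Int])
  val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
    featureSubsetStrategy = "auto", seed = 123)
  println(s"trained ${rf.trees.size} trees")   // rf.trees replaces rf.weakHypotheses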
*/ -class BaggedPointSuite extends FunSuite with LocalSparkContext { +class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 0dbe766b4d917..88bc49cc61f94 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.util.Utils -class MLUtilsSuite extends FunSuite with LocalSparkContext { +class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala similarity index 89% rename from mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala rename to mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 7857d9e5ee5c4..b658889476d37 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -22,15 +22,15 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkConf, SparkContext} -trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => +trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ override def beforeAll() { + super.beforeAll() val conf = new SparkConf() - .setMaster("local") - .setAppName("test") + .setMaster("local[2]") + .setAppName("MLlibUnitTest") sc = new SparkContext(conf) - super.beforeAll() } override def afterAll() { diff --git a/network/common/pom.xml b/network/common/pom.xml index 8b24ebf1ba1f2..baca859fa5011 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -41,16 +41,16 @@ io.netty netty-all + + org.slf4j slf4j-api + provided - - com.google.guava guava - 11.0.2 provided diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 5fa1527ddff92..844eff4f4c701 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -31,24 +31,19 @@ import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.network.util.TransportConf; /** * A {@link ManagedBuffer} backed by a segment in a file. */ public final class FileSegmentManagedBuffer extends ManagedBuffer { - - /** - * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). - * Avoid unless there's a good reason not to. 
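The renamed MLlibTestSparkContext trait above now calls super.beforeAll() first and creates a local[2] context named MLlibUnitTest. A minimal sketch of a suite mixing it in (class name hypothetical):

  import org.scalatest.FunSuite
  import org.apache.spark.mllib.util.MLlibTestSparkContext

  class MyEstimatorSuite extends FunSuite with MLlibTestSparkContext {
    // `sc` is provided by the trait: created in beforeAll(), stopped in afterAll().
    test("runs against the shared local[2] context") {
      val rdd = sc.parallelize(1 to 10, 2)
      assert(rdd.count() === 10)
    }
  }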
- */ - // TODO: Make this configurable - private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; - + private final TransportConf conf; private final File file; private final long offset; private final long length; - public FileSegmentManagedBuffer(File file, long offset, long length) { + public FileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length) { + this.conf = conf; this.file = file; this.offset = offset; this.length = length; @@ -65,7 +60,7 @@ public ByteBuffer nioByteBuffer() throws IOException { try { channel = new RandomAccessFile(file, "r").getChannel(); // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. - if (length < MIN_MEMORY_MAP_BYTES) { + if (length < conf.memoryMapBytes()) { ByteBuffer buf = ByteBuffer.allocate((int) length); channel.position(offset); while (buf.remaining() != 0) { @@ -134,8 +129,12 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - FileChannel fileChannel = new FileInputStream(file).getChannel(); - return new DefaultFileRegion(fileChannel, offset, length); + if (conf.lazyFileDescriptor()) { + return new LazyFileRegion(file, offset, length); + } else { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } } public File getFile() { return file; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java new file mode 100644 index 0000000000000..81bc8ec40fc82 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.FileInputStream; +import java.io.File; +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Objects; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A FileRegion implementation that only creates the file descriptor when the region is being + * transferred. This cannot be used with Epoll because there is no native support for it. + * + * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we + * should push this into Netty so the native Epoll transport can support this feature. 
+ */ +public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final File file; + private final long position; + private final long count; + + private FileChannel channel; + + private long numBytesTransferred = 0L; + + /** + * @param file file to transfer. + * @param position start position for the transfer. + * @param count number of bytes to transfer starting from position. + */ + public LazyFileRegion(File file, long position, long count) { + this.file = file; + this.position = position; + this.count = count; + } + + @Override + protected void deallocate() { + JavaUtils.closeQuietly(channel); + } + + @Override + public long position() { + return position; + } + + @Override + public long transfered() { + return numBytesTransferred; + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (channel == null) { + channel = new FileInputStream(file).getChannel(); + } + + long count = this.count - position; + if (count < 0 || position < 0) { + throw new IllegalArgumentException( + "position out of range: " + position + " (expected: 0 - " + (count - 1) + ')'); + } + + if (count == 0) { + return 0L; + } + + long written = channel.transferTo(this.position + position, count, target); + if (written > 0) { + numBytesTransferred += written; + } + return written; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("position", position) + .add("count", count) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 4e944114e8176..37f2e34ceb24d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -49,7 +49,7 @@ * to perform this setup. * * For example, a typical workflow might be: - * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 + * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) * ... 
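FileSegmentManagedBuffer above now takes a TransportConf, and convertToNetty() returns the new LazyFileRegion unless spark.shuffle.io.lazyFD is disabled. A minimal Scala sketch of constructing a buffer under the new signature, assuming the SystemPropertyConfigProvider test helper is available and using a hypothetical file path:

  import java.io.File

  import org.apache.spark.network.buffer.FileSegmentManagedBuffer
  import org.apache.spark.network.util.{SystemPropertyConfigProvider, TransportConf}

  val conf = new TransportConf(new SystemPropertyConfigProvider())
  val dataFile = new File("/tmp/shuffle_0_0_0.data")           // hypothetical path
  val buffer = new FileSegmentManagedBuffer(conf, dataFile, 0, dataFile.length())
  val bytes = buffer.nioByteBuffer()    // copies small segments, memory-maps those >= conf.memoryMapBytes()
  val region = buffer.convertToNetty()  // LazyFileRegion by default, DefaultFileRegion if lazyFD is off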
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 397d3a8455c86..9afd5decd5e6b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -19,7 +19,6 @@ import java.io.Closeable; import java.io.IOException; -import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.List; @@ -37,7 +36,6 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; -import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -67,6 +65,7 @@ public class TransportClientFactory implements Closeable { private final Class socketChannelClass; private EventLoopGroup workerGroup; + private PooledByteBufAllocator pooledAllocator; public TransportClientFactory( TransportContext context, @@ -80,6 +79,8 @@ public TransportClientFactory( this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); // TODO: Make thread pool name configurable. this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( + conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); } /** @@ -115,10 +116,8 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, true) .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()); - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) + .option(ChannelOption.ALLOCATOR, pooledAllocator); final AtomicReference clientRef = new AtomicReference(); @@ -190,34 +189,4 @@ public void close() { workerGroup = null; } } - - /** - * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches - * are disabled because the ByteBufs are allocated by the event loop thread, but released by the - * executor thread rather than the event loop thread. Those thread-local caches actually delay - * the recycling of buffers, leading to larger memory usage. - */ - private PooledByteBufAllocator createPooledByteBufAllocator() { - return new PooledByteBufAllocator( - conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(), - getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), - getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - 0, // tinyCacheSize - 0, // smallCacheSize - 0 // normalCacheSize - ); - } - - /** Used to get defaults from Netty's private static fields. 
*/ - private int getPrivateStaticField(String name) { - try { - Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); - f.setAccessible(true); - return f.getInt(null); - } catch (Exception e) { - throw new RuntimeException(e); - } - } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 731d48d4d9c6c..a6d390e13f396 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -29,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; /** - * StreamManager which allows registration of an Iterator, which are individually + * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually * fetched as chunks by the client. Each registered buffer is one chunk. */ public class OneForOneStreamManager extends StreamManager { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 579676c2c3564..625c3257d764e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -72,8 +72,8 @@ private void init(int portToBind) { NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); EventLoopGroup workerGroup = bossGroup; - PooledByteBufAllocator allocator = new PooledByteBufAllocator( - conf.preferDirectBufs() && PlatformDependent.directBufferPreferred()); + PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator( + conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads()); bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java index 63ca43c046525..57113ed12d414 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -27,7 +27,7 @@ * Wraps a {@link InputStream}, limiting the number of bytes which can be read. * * This code is from Guava's 14.0 source code, because there is no compatible way to - * use this functionality in both a Guava 11 environment and a Guava >14 environment. + * use this functionality in both a Guava 11 environment and a Guava >14 environment. 
*/ public final class LimitedInputStream extends FilterInputStream { private long left; diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 2a7664fe89388..2a4b88b64cdc9 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -17,9 +17,11 @@ package org.apache.spark.network.util; +import java.lang.reflect.Field; import java.util.concurrent.ThreadFactory; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; @@ -32,6 +34,7 @@ import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.util.internal.PlatformDependent; /** * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. @@ -96,11 +99,47 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } - /** Returns the remote address on the channel or "" if none exists. */ + /** Returns the remote address on the channel or "<remote address>" if none exists. */ public static String getRemoteAddress(Channel channel) { if (channel != null && channel.remoteAddress() != null) { return channel.remoteAddress().toString(); } return ""; } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches + * are disabled for TransportClients because the ByteBufs are allocated by the event loop thread, + * but released by the executor thread rather than the event loop thread. Those thread-local + * caches actually delay the recycling of buffers, leading to larger memory usage. + */ + public static PooledByteBufAllocator createPooledByteBufAllocator( + boolean allowDirectBufs, + boolean allowCache, + int numCores) { + if (numCores == 0) { + numCores = Runtime.getRuntime().availableProcessors(); + } + return new PooledByteBufAllocator( + allowDirectBufs && PlatformDependent.directBufferPreferred(), + Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores), + Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0, + allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0, + allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0 + ); + } + + /** Used to get defaults from Netty's private static fields. 
*/ + private static int getPrivateStaticField(String name) { + try { + Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.getInt(null); + } catch (Exception e) { + throw new RuntimeException(e); + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 787a8f0031af1..1af40acf8b4af 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -72,7 +72,24 @@ public int connectionTimeoutMs() { /** * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. - * Only relevant if maxIORetries > 0. + * Only relevant if maxIORetries > 0. */ public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } + + /** + * Minimum size of a block that we should start using memory map rather than reading in through + * normal IO operations. This prevents Spark from memory mapping very small blocks. In general, + * memory mapping has high overhead for blocks close to or below the page size of the OS. + */ + public int memoryMapBytes() { + return conf.getInt("spark.storage.memoryMapThreshold", 2 * 1024 * 1024); + } + + /** + * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * created only when data is going to be transferred. This can reduce the number of open files. + */ + public boolean lazyFileDescriptor() { + return conf.getBoolean("spark.shuffle.io.lazyFD", true); + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index c4158833976aa..dfb7740344ed0 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -63,6 +63,8 @@ public class ChunkFetchIntegrationSuite { static ManagedBuffer bufferChunk; static ManagedBuffer fileChunk; + private TransportConf transportConf; + @BeforeClass public static void setUp() throws Exception { int bufSize = 100000; @@ -80,9 +82,10 @@ public static void setUp() throws Exception { new Random().nextBytes(fileContent); fp.write(fileContent); fp.close(); - fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); + streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { @@ -90,7 +93,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { if (chunkIndex == BUFFER_CHUNK_INDEX) { return new NioManagedBuffer(buf); } else if (chunkIndex == FILE_CHUNK_INDEX) { - return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); } else { throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); } diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 27c8467687f10..12468567c3aed 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 
+22,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,26 +39,26 @@ org.apache.spark - spark-network-common_2.10 + spark-network-common_${scala.binary.version} ${project.version} + + org.slf4j slf4j-api + provided - - com.google.guava guava - 11.0.2 provided org.apache.spark - spark-network-common_2.10 + spark-network-common_${scala.binary.version} ${project.version} test-jar test diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index a6db4b2abd6c9..46ca9708621b9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; +import org.apache.spark.network.util.TransportConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,8 +49,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler { private final ExternalShuffleBlockManager blockManager; private final OneForOneStreamManager streamManager; - public ExternalShuffleBlockHandler() { - this(new OneForOneStreamManager(), new ExternalShuffleBlockManager()); + public ExternalShuffleBlockHandler(TransportConf conf) { + this(new OneForOneStreamManager(), new ExternalShuffleBlockManager(conf)); } /** Enables mocking out the StreamManager and BlockManager. */ diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index ffb7faa3dbdca..dfe0ba0595090 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -37,6 +37,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; /** * Manages converting shuffle BlockIds into physical segments of local files, from a process outside @@ -56,14 +57,17 @@ public class ExternalShuffleBlockManager { // Single-threaded Java executor used to perform expensive recursive directory deletion. private final Executor directoryCleaner; - public ExternalShuffleBlockManager() { + private final TransportConf conf; + + public ExternalShuffleBlockManager(TransportConf conf) { // TODO: Give this thread a name. - this(Executors.newSingleThreadExecutor()); + this(conf, Executors.newSingleThreadExecutor()); } // Allows tests to have more control over when directories are cleaned up. 
@VisibleForTesting - ExternalShuffleBlockManager(Executor directoryCleaner) { + ExternalShuffleBlockManager(TransportConf conf, Executor directoryCleaner) { + this.conf = conf; this.executors = Maps.newConcurrentMap(); this.directoryCleaner = directoryCleaner; } @@ -167,7 +171,7 @@ private void deleteExecutorDirs(String[] dirs) { // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); - return new FileSegmentManagedBuffer(shuffleFile, 0, shuffleFile.length()); + return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); } /** @@ -187,6 +191,7 @@ private ManagedBuffer getSortBasedShuffleBlockData( long offset = in.readLong(); long nextOffset = in.readLong(); return new FileSegmentManagedBuffer( + conf, getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.data"), offset, diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java index 60485bace643c..62fce9b0d16cd 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -23,6 +23,7 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to read a set of blocks. Returns {@link StreamHandle}. */ public class OpenBlocks extends BlockTransferMessage { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index 38acae3b31d64..7eb4385044077 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** * Initial registration message between an executor and its local shuffle server. diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index 21369c8cfb0d6..bc9daa6158ba3 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -17,11 +17,11 @@ package org.apache.spark.network.shuffle.protocol; -import java.io.Serializable; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. 
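The TransportConf additions earlier in this patch expose the memory-map threshold and the lazy file descriptor switch as spark.storage.memoryMapThreshold and spark.shuffle.io.lazyFD, and the external shuffle handler and block manager above now receive the conf at construction time. A minimal sketch of overriding both settings and building a handler, assuming SystemPropertyConfigProvider reads JVM system properties as its name suggests:

  import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
  import org.apache.spark.network.util.{SystemPropertyConfigProvider, TransportConf}

  System.setProperty("spark.storage.memoryMapThreshold", (8 * 1024 * 1024).toString)
  System.setProperty("spark.shuffle.io.lazyFD", "false")
  val conf = new TransportConf(new SystemPropertyConfigProvider())
  assert(conf.memoryMapBytes() == 8 * 1024 * 1024)
  assert(!conf.lazyFileDescriptor())
  val handler = new ExternalShuffleBlockHandler(conf)   // conf flows down to ExternalShuffleBlockManager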
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java index 38abe29cc585f..0b23e112bd512 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -23,6 +23,8 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ public class UploadBlock extends BlockTransferMessage { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java index da54797e8923c..dad6428a836fc 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java @@ -22,6 +22,8 @@ import java.io.InputStreamReader; import com.google.common.io.CharStreams; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -37,6 +39,8 @@ public class ExternalShuffleBlockManagerSuite { static TestShuffleDataContext dataContext; + static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + @BeforeClass public static void beforeAll() throws IOException { dataContext = new TestShuffleDataContext(2, 5); @@ -56,7 +60,7 @@ public static void afterAll() { @Test public void testBadRequests() { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); // Unregistered executor try { manager.getBlockData("app0", "exec1", "shuffle_1_1_0"); @@ -87,7 +91,7 @@ public void testBadRequests() { @Test public void testSortShuffleBlocks() throws IOException { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); manager.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); @@ -106,7 +110,7 @@ public void testSortShuffleBlocks() throws IOException { @Test public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); manager.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index c8ece3bc53ac3..254e3a7a32b98 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -25,20 +25,23 @@ import com.google.common.util.concurrent.MoreExecutors; import org.junit.Test; - import static 
org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { TestShuffleDataContext dataContext = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", false /* cleanup */); @@ -61,7 +64,7 @@ public void cleanupUsesExecutor() throws IOException { @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } }; - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(noThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, noThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", true); @@ -78,7 +81,7 @@ public void cleanupMultipleExecutors() throws IOException { TestShuffleDataContext dataContext0 = createSomeData(); TestShuffleDataContext dataContext1 = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); @@ -93,7 +96,7 @@ public void cleanupOnlyRemovedApp() throws IOException { TestShuffleDataContext dataContext0 = createSomeData(); TestShuffleDataContext dataContext1 = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(sameThreadExecutor); + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 687bde59fdae4..02c10bcb7b261 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -92,7 +92,7 @@ public static void beforeAll() throws IOException { dataContext1.insertHashShuffleData(1, 0, exec1Blocks); conf = new TransportConf(new SystemPropertyConfigProvider()); - handler = new ExternalShuffleBlockHandler(); + handler = new ExternalShuffleBlockHandler(conf); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java 
b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 8afceab1d585a..759a12910c94d 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -42,7 +42,7 @@ public class ExternalShuffleSecuritySuite { @Before public void beforeEach() { - RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(), + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf), new TestSecretKeyHolder("my-app-id", "secret")); TransportContext context = new TransportContext(conf, handler); this.server = context.createServer(); diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 6e6f6f3e79296..acec8f18f2b5c 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,7 +39,7 @@ org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_${scala.binary.version} ${project.version} @@ -54,5 +54,38 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index bb0b8f7e6cba6..a34aabe9e78a6 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -95,10 +95,11 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - RpcHandler rpcHandler = new ExternalShuffleBlockHandler(); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); if (authEnabled) { secretManager = new ShuffleSecretManager(); rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); @@ -106,7 +107,6 @@ protected void serviceInit(Configuration conf) { int port = conf.getInt( SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); TransportContext transportContext = new TransportContext(transportConf, rpcHandler); shuffleServer = transportContext.createServer(port); String authEnabledString = authEnabled ? 
"enabled" : "not enabled"; diff --git a/pom.xml b/pom.xml index 88ef67c515b3a..b7df53d3e5eb1 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -97,30 +97,26 @@ sql/catalyst sql/core sql/hive - repl assembly external/twitter - external/kafka external/flume external/flume-sink - external/zeromq external/mqtt + external/zeromq examples + repl UTF-8 UTF-8 - + org.spark-project.akka + 2.3.4-spark 1.6 spark - 2.10.4 - 2.10 2.0.1 0.18.1 shaded-protobuf - org.spark-project.akka - 2.3.4-spark 1.7.5 1.2.17 1.0.4 @@ -137,7 +133,7 @@ 1.6.0rc3 1.2.3 8.1.14.v20131031 - 0.3.6 + 0.5.0 3.0.0 1.7.6 @@ -146,9 +142,13 @@ 1.1.0 4.2.6 3.1.1 - + ${project.build.directory}/spark-test-classpath.txt 64m 512m + 2.10.4 + 2.10 + ${scala.version} + org.scala-lang @@ -230,23 +230,11 @@ false + - - spark-staging - Spring Staging Repository - https://oss.sonatype.org/content/repositories/orgspark-project-1085 - - true - - - false - - - - - spark-staging-hive13 - Spring Staging Repository Hive 13 - https://oss.sonatype.org/content/repositories/orgspark-project-1089/ + spark-staging-1038 + Spark 1.2.0 Staging (1038) + https://repository.apache.org/content/repositories/orgapachespark-1038/ true @@ -267,19 +255,66 @@ - - org.spark-project.spark unused 1.0.0 + + + org.codehaus.groovy + groovy-all + 2.3.7 + provided + + + ${jline.groupid} + jline + ${jline.version} + + + com.twitter + chill_${scala.binary.version} + ${chill.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + + + com.twitter + chill-java + ${chill.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + org.eclipse.jetty jetty-util @@ -366,7 +401,7 @@ org.xerial.snappy snappy-java - 1.1.1.3 + 1.1.1.6 net.jpountz.lz4 @@ -395,36 +430,6 @@ protobuf-java ${protobuf.version} - - com.twitter - chill_${scala.binary.version} - ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - - - - com.twitter - chill-java - ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - - ${akka.group} akka-actor_${scala.binary.version} @@ -512,11 +517,6 @@ scala-reflect ${scala.version} - - org.scala-lang - jline - ${scala.version} - org.scala-lang scala-library @@ -965,6 +965,9 @@ ${session.executionRootDirectory} 1 false + false + ${test_classpath} + true @@ -1022,10 +1025,56 @@ + + org.apache.maven.plugins + maven-javadoc-plugin + 2.10.1 + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.9 + + + test-compile + + build-classpath + + + test + ${test_classpath_file} + + + + + + + + org.codehaus.gmavenplus + gmavenplus-plugin + 1.2 + + + process-test-classes + + execute + + + + + + + + + org.apache.maven.plugins @@ -1174,6 +1223,25 @@
    + + doclint-java8-disable + + [1.8,) + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + -Xdoclint:all -Xdoclint:-missing + + + + + + hadoop-provided - - false - org.apache.hadoop @@ -1336,19 +1394,13 @@ - hive - - false - + hive-thriftserver sql/hive-thriftserver hive-0.12.0 - - false - 0.12.0-protobuf-2.5 0.12.0 @@ -1357,14 +1409,41 @@ hive-0.13.1 - - false - 0.13.1a 0.13.1 10.10.1.1 + + + scala-2.10 + + !scala-2.11 + + + 2.10.4 + 2.10 + ${scala.version} + org.scala-lang + + + external/kafka + + + + + scala-2.11 + + scala-2.11 + + + 2.11.2 + 2.11 + 2.12 + jline + + + diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index d919b18e09855..f0cbf4e57b8c5 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -30,7 +30,7 @@ object MimaBuild { def excludeMember(fullName: String) = Seq( ProblemFilters.exclude[MissingMethodProblem](fullName), - // Sometimes excluded methods have default arguments and + // Sometimes excluded methods have default arguments and // they are translated into public methods/fields($default$) in generated // bytecode. It is not possible to exhaustively list everything. // But this should be okay. @@ -91,9 +91,9 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.1.0" + val previousSparkVersion = "1.2.0" val fullId = "spark-" + projectRef.project + "_2.10" - mimaDefaultSettings ++ + mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a94d09be3bec6..230239aa40500 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -33,6 +33,28 @@ import com.typesafe.tools.mima.core._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.3") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in the 1.2 build. 
+ MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ) ++ Seq( + // SPARK-2321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfoImpl.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfo.submissionTime") + ) ++ Seq( + // SPARK-4614 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.randn"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.rand") + ) + case v if v.startsWith("1.2") => Seq( MimaBuild.excludeSparkPackage("deploy"), @@ -85,6 +107,10 @@ object MimaExcludes { "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), ProblemFilters.exclude[MissingTypesProblem]( "org.apache.spark.rdd.PairRDDFunctions") + ) ++ Seq( + // SPARK-4062 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") ) case v if v.startsWith("1.1") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 657e4b4432775..b16ed66aeb3c3 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -31,19 +31,19 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq) = + sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, + streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, networkYarn, java8Tests, - sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "network-yarn", + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, + sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") - .map(ProjectRef(buildLocation, _)) + val assemblyProjects@Seq(assembly, examples, networkYarn) = + Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. 
@@ -68,8 +68,8 @@ object SparkBuild extends PomBuild { profiles ++= Seq("spark-ganglia-lgpl") } if (Properties.envOrNone("SPARK_HIVE").isDefined) { - println("NOTE: SPARK_HIVE is deprecated, please use -Phive flag.") - profiles ++= Seq("hive") + println("NOTE: SPARK_HIVE is deprecated, please use -Phive and -Phive-thriftserver flags.") + profiles ++= Seq("hive", "hive-thriftserver") } Properties.envOrNone("SPARK_HADOOP_VERSION") match { case Some(v) => @@ -91,13 +91,23 @@ object SparkBuild extends PomBuild { profiles } - override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { + override val profiles = { + val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { case None => backwardCompatibility case Some(v) => if (backwardCompatibility.nonEmpty) println("Note: We ignore environment variables, when use of profile is detected in " + "conjunction with environment variable.") v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq + } + + if (System.getProperty("scala-2.11") == "") { + // To activate scala-2.11 profile, replace empty property value to non-empty value + // in the same way as Maven which handles -Dname as -Dname=true before executes build process. + // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 + System.setProperty("scala-2.11", "true") + } + profiles } Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { @@ -126,7 +136,12 @@ object SparkBuild extends PomBuild { }, publishMavenStyle in MavenCompile := true, publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), - publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn + publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn, + + javacOptions in (Compile, doc) ++= { + val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3) + if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty + } ) def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { @@ -136,7 +151,8 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ - (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings)) + (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) + .foreach(enable(sharedSettings ++ ExludedDependencies.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -178,6 +194,16 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + This excludes library dependencies in sbt, which are specified in maven but are + not needed by sbt build. + */ +object ExludedDependencies { + lazy val settings = Seq( + libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } + ) +} + /** * Following project only exists to pull previous artifacts of Spark for generating * Mima ignores. 
For more information see: SPARK 2071 @@ -188,12 +214,14 @@ object OldDeps { def versionArtifact(id: String): Option[sbt.ModuleID] = { val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.1.0") + Some("org.apache.spark" % fullId % "1.2.0") } def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", scalaVersion := "2.10.4", + // TODO: remove this as soon as 1.2.0 is published on Maven central. + resolvers += "spark-staging-1038" at "https://repository.apache.org/content/repositories/orgapachespark-1038/", retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", @@ -270,8 +298,15 @@ object Assembly { lazy val settings = assemblySettings ++ Seq( test in assembly := {}, - jarName in assembly <<= (version, moduleName) map { (v, mName) => mName + "-"+v + "-hadoop" + - Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" }, + jarName in assembly <<= (version, moduleName) map { (v, mName) => + if (mName.contains("network-yarn")) { + // This must match the same name used in maven (see network/yarn/pom.xml) + "spark-" + v + "-yarn-shuffle.jar" + } else { + mName + "-" + v + "-hadoop" + + Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" + } + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -302,7 +337,7 @@ object Unidoc { unidocProjectFilter in(ScalaUnidoc, unidoc) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { @@ -348,13 +383,18 @@ object TestSettings { javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", javaOptions in Test += "-Dspark.ui.enabled=false", + javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", + javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, + // This places test scope jars on the classpath of executors during tests. + javaOptions in Test += + "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. + map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", - // Show full stack trace and duration in test cases. 
testOptions in Test += Tests.Argument("-oDF"), testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 3ef2d5451da0d..8863f272da415 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -26,7 +26,7 @@ import sbt.Keys._ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) - lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git") + lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") // There is actually no need to publish this artifact. def styleSettings = Defaults.defaultSettings ++ Seq ( diff --git a/python/docs/epytext.py b/python/docs/epytext.py index 19fefbfc057a4..e884d5e6b19c7 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -1,7 +1,7 @@ import re RULES = ( - (r"<[\w.]+>", r""), + (r"<(!BLANKLINE)[\w.]+>", r""), (r"L{([\w.()]+)}", r":class:`\1`"), (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), (r"C{([\w.()]+)}", r":class:`\1`"), diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index e39e6514d77a1..9556e4718e585 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -37,16 +37,6 @@ """ -# The following block allows us to import python's random instead of mllib.random for scripts in -# mllib that depend on top level pyspark packages, which transitively depend on python's random. -# Since Python's import logic looks for modules in the current package first, we eliminate -# mllib.random as a candidate for C{import random} by removing the first search path, the script's -# location, in order to force the loader to look in Python's top-level modules for C{random}. -import sys -s = sys.path.pop(0) -import random -sys.path.insert(0, s) - from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index f124dc6c07575..6b8a8b256a891 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -15,21 +15,10 @@ # limitations under the License. # -""" ->>> from pyspark.context import SparkContext ->>> sc = SparkContext('local', 'test') ->>> b = sc.broadcast([1, 2, 3, 4, 5]) ->>> b.value -[1, 2, 3, 4, 5] ->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() -[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] ->>> b.unpersist() - ->>> large_broadcast = sc.broadcast(list(range(10000))) -""" import os - -from pyspark.serializers import CompressedSerializer, PickleSerializer +import cPickle +import gc +from tempfile import NamedTemporaryFile __all__ = ['Broadcast'] @@ -49,44 +38,88 @@ def _from_id(bid): class Broadcast(object): """ - A broadcast variable created with - L{SparkContext.broadcast()}. + A broadcast variable created with L{SparkContext.broadcast()}. Access its value through C{.value}. 
+ + Examples: + + >>> from pyspark.context import SparkContext + >>> sc = SparkContext('local', 'test') + >>> b = sc.broadcast([1, 2, 3, 4, 5]) + >>> b.value + [1, 2, 3, 4, 5] + >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + >>> b.unpersist() + + >>> large_broadcast = sc.broadcast(range(10000)) """ - def __init__(self, bid, value, java_broadcast=None, - pickle_registry=None, path=None): + def __init__(self, sc=None, value=None, pickle_registry=None, path=None): """ - Should not be called directly by users -- use - L{SparkContext.broadcast()} + Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ - self.bid = bid - if path is None: - self._value = value - self._jbroadcast = java_broadcast - self._pickle_registry = pickle_registry - self.path = path + if sc is not None: + f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) + self._path = self.dump(value, f) + self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path) + self._pickle_registry = pickle_registry + else: + self._jbroadcast = None + self._path = path + + def dump(self, value, f): + if isinstance(value, basestring): + if isinstance(value, unicode): + f.write('U') + value = value.encode('utf8') + else: + f.write('S') + f.write(value) + else: + f.write('P') + cPickle.dump(value, f, 2) + f.close() + return f.name + + def load(self, path): + with open(path, 'rb', 1 << 20) as f: + flag = f.read(1) + data = f.read() + if flag == 'P': + # cPickle.loads() may create lots of objects, disable GC + # temporary for better performance + gc.disable() + try: + return cPickle.loads(data) + finally: + gc.enable() + else: + return data.decode('utf8') if flag == 'U' else data @property def value(self): """ Return the broadcasted value """ - if not hasattr(self, "_value") and self.path is not None: - ser = CompressedSerializer(PickleSerializer()) - self._value = ser.load_stream(open(self.path)).next() + if not hasattr(self, "_value") and self._path is not None: + self._value = self.load(self._path) return self._value def unpersist(self, blocking=False): """ Delete cached copies of this broadcast on the executors. """ + if self._jbroadcast is None: + raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) - os.unlink(self.path) + os.unlink(self._path) def __reduce__(self): + if self._jbroadcast is None: + raise Exception("Broadcast can only be serialized in driver") self._pickle_registry.add(self) - return (_from_id, (self.bid, )) + return _from_id, (self._jbroadcast.id(),) if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index faa5952258aef..ed7351d60cff2 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer + PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -289,12 +289,29 @@ def stop(self): def parallelize(self, c, numSlices=None): """ - Distribute a local Python collection to form an RDD. + Distribute a local Python collection to form an RDD. Using xrange + is recommended if the input represents a range for performance. 
- >>> sc.parallelize(range(5), 5).glom().collect() - [[0], [1], [2], [3], [4]] + >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect() + [[0], [2], [3], [4], [6]] + >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect() + [[], [0], [], [2], [4]] """ - numSlices = numSlices or self.defaultParallelism + numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism + if isinstance(c, xrange): + size = len(c) + if size == 0: + return self.parallelize([], numSlices) + step = c[1] - c[0] if size > 1 else 1 + start0 = c[0] + + def getStart(split): + return start0 + (split * size / numSlices) * step + + def f(split, iterator): + return xrange(getStart(split), getStart(split + 1), step) + + return self.parallelize([], numSlices).mapPartitionsWithIndex(f) # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). @@ -607,14 +624,7 @@ def broadcast(self, value): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - ser = CompressedSerializer(PickleSerializer()) - # pass large object by py4j is very slow and need much memory - tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - ser.dump_stream([value], tempFile) - tempFile.close() - jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) - return Broadcast(jbroadcast.id(), None, jbroadcast, - self._pickled_broadcast_vars, tempFile.name) + return Broadcast(self, value, self._pickled_broadcast_vars) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 9c70fa5c16d0c..a975dc19cb78e 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -45,7 +45,9 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + env = dict(os.environ) + env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows proc = Popen(command, stdout=PIPE, stdin=PIPE) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 4149f54931d1f..5030a655fcbba 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -24,3 +24,37 @@ import numpy if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") + +__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', + 'recommendation', 'regression', 'stat', 'tree', 'util'] + +import sys +import rand as random +random.__name__ = 'random' +random.RandomRDDs.__module__ = __name__ + '.random' + + +class RandomModuleHook(object): + """ + Hook to import pyspark.mllib.random + """ + fullname = __name__ + '.random' + + def find_module(self, name, path=None): + # skip all other modules + if not name.startswith(self.fullname): + return + return self + + def load_module(self, name): + if name == self.fullname: + return random + + cname = name.rsplit('.', 1)[-1] + try: + return getattr(random, cname) + except AttributeError: + raise ImportError + + +sys.meta_path.append(RandomModuleHook()) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 297a2bf37d2cf..f14d0ed11cbbb 100644 
--- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -20,96 +20,200 @@ import numpy from numpy import array +from pyspark import RDD from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper -__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel', - 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] +__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', + 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] -class LogisticRegressionModel(LinearModel): +class LinearBinaryClassificationModel(LinearModel): + """ + Represents a linear binary classification model that predicts to whether an + example is positive (1.0) or negative (0.0). + """ + def __init__(self, weights, intercept): + super(LinearBinaryClassificationModel, self).__init__(weights, intercept) + self._threshold = None + + def setThreshold(self, value): + """ + :: Experimental :: + + Sets the threshold that separates positive predictions from negative + predictions. An example with prediction score greater than or equal + to this threshold is identified as an positive, and negative otherwise. + """ + self._threshold = value + + def clearThreshold(self): + """ + :: Experimental :: + + Clears the threshold so that `predict` will output raw prediction scores. + """ + self._threshold = None + + def predict(self, test): + """ + Predict values for a single data point or an RDD of points using + the model trained. + """ + raise NotImplementedError + + +class LogisticRegressionModel(LinearBinaryClassificationModel): """A linear binary classification model derived from logistic regression. >>> data = [ - ... LabeledPoint(0.0, [0.0]), - ... LabeledPoint(1.0, [1.0]), - ... LabeledPoint(1.0, [2.0]), - ... LabeledPoint(1.0, [3.0]) + ... LabeledPoint(0.0, [0.0, 1.0]), + ... LabeledPoint(1.0, [1.0, 0.0]), ... ] >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data)) - >>> lrm.predict(array([1.0])) > 0 - True - >>> lrm.predict(array([0.0])) <= 0 - True + >>> lrm.predict([1.0, 0.0]) + 1 + >>> lrm.predict([0.0, 1.0]) + 0 + >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect() + [1, 0] + >>> lrm.clearThreshold() + >>> lrm.predict([0.0, 1.0]) + 0.123... + >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), - ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data)) - >>> lrm.predict(array([0.0, 1.0])) > 0 - True - >>> lrm.predict(array([0.0, 0.0])) <= 0 - True - >>> lrm.predict(SparseVector(2, {1: 1.0})) > 0 - True - >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0 - True + >>> lrm.predict(array([0.0, 1.0])) + 1 + >>> lrm.predict(array([1.0, 0.0])) + 0 + >>> lrm.predict(SparseVector(2, {1: 1.0})) + 1 + >>> lrm.predict(SparseVector(2, {0: 1.0})) + 0 """ + def __init__(self, weights, intercept): + super(LogisticRegressionModel, self).__init__(weights, intercept) + self._threshold = 0.5 def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. 
+ """ + if isinstance(x, RDD): + return x.map(lambda v: self.predict(v)) + + x = _convert_to_vector(x) margin = self.weights.dot(x) + self._intercept if margin > 0: prob = 1 / (1 + exp(-margin)) else: exp_margin = exp(margin) prob = exp_margin / (1 + exp_margin) - return 1 if prob > 0.5 else 0 + if self._threshold is None: + return prob + else: + return 1 if prob > self._threshold else 0 class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False): """ Train a logistic regression model on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.01). :param regType: The type of regularizer used for training our model. :Allowed values: - - "l1" for using L1Updater - - "l2" for using SquaredL2Updater - - "none" for no regularizer + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization - (default: "none") + (default: "l2") - @param intercept: Boolean parameter which indicates the use + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). """ def train(rdd, i): - return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step, - miniBatchFraction, i, regParam, regType, intercept) + return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), regType, + bool(intercept)) + + return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) + + +class LogisticRegressionWithLBFGS(object): + + @classmethod + def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", + intercept=False, corrections=10, tolerance=1e-4): + """ + Train a logistic regression model on the given data. + + :param data: The training data, an RDD of LabeledPoint. + :param iterations: The number of iterations (default: 100). + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 0.01). + :param regType: The type of regularizer used for training + our model. + + :Allowed values: + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization + + (default: "l2") + + :param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + :param corrections: The number of corrections used in the LBFGS + update (default: 10). + :param tolerance: The convergence tolerance of iterations for + L-BFGS (default: 1e-4). + + >>> data = [ + ... LabeledPoint(0.0, [0.0, 1.0]), + ... LabeledPoint(1.0, [1.0, 0.0]), + ... 
] + >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) + >>> lrm.predict([1.0, 0.0]) + 1 + >>> lrm.predict([0.0, 1.0]) + 0 + """ + def train(rdd, i): + return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, + float(regParam), str(regType), bool(intercept), int(corrections), + float(tolerance)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) -class SVMModel(LinearModel): +class SVMModel(LinearBinaryClassificationModel): """A support vector machine. @@ -120,8 +224,14 @@ class SVMModel(LinearModel): ... LabeledPoint(1.0, [3.0]) ... ] >>> svm = SVMWithSGD.train(sc.parallelize(data)) - >>> svm.predict(array([1.0])) > 0 - True + >>> svm.predict([1.0]) + 1 + >>> svm.predict(sc.parallelize([[1.0]])).collect() + [1] + >>> svm.clearThreshold() + >>> svm.predict(array([1.0])) + 1.25... + >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: -1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), @@ -129,30 +239,44 @@ class SVMModel(LinearModel): ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data)) - >>> svm.predict(SparseVector(2, {1: 1.0})) > 0 - True - >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0 - True + >>> svm.predict(SparseVector(2, {1: 1.0})) + 1 + >>> svm.predict(SparseVector(2, {0: -1.0})) + 0 """ + def __init__(self, weights, intercept): + super(SVMModel, self).__init__(weights, intercept) + self._threshold = 0.0 def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. + """ + if isinstance(x, RDD): + return x.map(lambda v: self.predict(v)) + + x = _convert_to_vector(x) margin = self.weights.dot(x) + self.intercept - return 1 if margin >= 0 else 0 + if self._threshold is None: + return margin + else: + return 1 if margin > self._threshold else 0 class SVMWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False): + def train(cls, data, iterations=100, step=1.0, regParam=0.01, + miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): """ Train a support vector machine on the given data. - :param data: The training data. + :param data: The training data, an RDD of LabeledPoint. :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.01). :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). @@ -160,20 +284,21 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, our model. :Allowed values: - - "l1" for using L1Updater - - "l2" for using SquaredL2Updater, - - "none" for no regularizer. + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization - (default: "none") + (default: "l2") - @param intercept: Boolean parameter which indicates the use + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). 
""" def train(rdd, i): - return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i, regType, intercept) + return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i, regType, + bool(intercept)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) @@ -197,6 +322,8 @@ class NaiveBayesModel(object): 0.0 >>> model.predict(array([1.0, 0.0])) 1.0 + >>> model.predict(sc.parallelize([[1.0, 0.0]])).collect() + [1.0] >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {1: 0.0})), ... LabeledPoint(0.0, SparseVector(2, {1: 1.0})), @@ -215,7 +342,9 @@ def __init__(self, labels, pi, theta): self.theta = theta def predict(self, x): - """Return the most likely class for a data vector x""" + """Return the most likely class for a data vector or an RDD of vectors""" + if isinstance(x, RDD): + return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] @@ -233,11 +362,12 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - :param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector - (e.g. a count vector). + :param data: RDD of LabeledPoint. :param lambda_: The smoothing parameter """ + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("`data` should be an RDD of LabeledPoint") labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) @@ -245,7 +375,8 @@ def train(cls, data, lambda_=1.0): def _test(): import doctest from pyspark import SparkContext - globs = globals().copy() + import pyspark.mllib.classification + globs = pyspark.mllib.classification.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index fe4c4cc5094d8..e2492eef5bd6a 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -16,7 +16,7 @@ # from pyspark import SparkContext -from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, callJavaFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['KMeansModel', 'KMeans'] @@ -80,10 +80,8 @@ class KMeans(object): @classmethod def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" - # cache serialized data to avoid objects over head in JVM - jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True) - model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs, - initializationMode) + model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, + runs, initializationMode) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index c6149fe391ec8..33c49e2399908 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -54,15 +54,13 @@ def _new_smart_decode(obj): # this will 
call the MLlib version of pythonToJava() -def _to_java_object_rdd(rdd, cache=False): +def _to_java_object_rdd(rdd): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. """ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - if cache: - rdd.cache() return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 44bf6f269d7a3..8cb992df2d9c7 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -18,14 +18,17 @@ """ Python package for feature in MLlib. """ +from __future__ import absolute_import + import sys import warnings +import random from py4j.protocol import Py4JJavaError from pyspark import RDD, SparkContext from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg import Vectors, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] @@ -81,12 +84,16 @@ def transform(self, vector): """ Applies unit length normalization on a vector. - :param vector: vector to be normalized. + :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) return callMLlibFunc("normalizeVector", self.p, vector) @@ -95,8 +102,12 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): Wrapper for the model in JVM """ - def transform(self, dataset): - return self.call("transform", dataset) + def transform(self, vector): + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) + return self.call("transform", vector) class StandardScalerModel(JavaVectorTransformer): @@ -109,7 +120,7 @@ def transform(self, vector): """ Applies standardization transformation on a vector. - :param vector: Vector to be standardized. + :param vector: Vector or RDD of Vector to be standardized. :return: Standardized vector. If the variance of a column is zero, it will return default `0.0` for the column with zero variance. """ @@ -154,6 +165,7 @@ def fit(self, dataset): the transformation model. 
:return: a StandardScalarModel """ + dataset = dataset.map(_convert_to_vector) jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset) return StandardScalerModel(jmodel) @@ -211,6 +223,8 @@ def transform(self, dataset): :param dataset: an RDD of term frequency vectors :return: an RDD of TF-IDF vectors """ + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") return JavaVectorTransformer.transform(self, dataset) @@ -255,7 +269,9 @@ def fit(self, dataset): :param dataset: an RDD of term frequency vectors """ - jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset) + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") + jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector)) return IDFModel(jmodel) @@ -287,6 +303,8 @@ def findSynonyms(self, word, num): Note: local use only """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) @@ -326,8 +344,6 @@ def __init__(self): """ Construct Word2Vec instance """ - import random # this can't be on the top because of mllib.random - self.vectorSize = 100 self.learningRate = 0.025 self.numPartitions = 1 @@ -374,9 +390,11 @@ def fit(self, data): """ Computes the vector representation of each word in vocabulary. - :param data: training data. RDD of subtype of Iterable[String] + :param data: training data. RDD of list of string :return: Word2VecModel instance """ + if not isinstance(data, RDD): + raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), int(self.numIterations), long(self.seed)) @@ -394,8 +412,5 @@ def _test(): exit(-1) if __name__ == "__main__": - # remove current path from list of search paths to avoid importing mllib.random - # for C{import random}, which is done in an external dependency of pyspark during doctests. - import sys sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index e35202dca0acc..f7aa2b0cb04b3 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -30,7 +30,7 @@ import numpy as np from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ - IntegerType, ByteType, Row + IntegerType, ByteType __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] @@ -173,12 +173,16 @@ class DenseVector(Vector): A dense vector represented by a value array. """ def __init__(self, ar): - if not isinstance(ar, array.array): - ar = array.array('d', ar) + if isinstance(ar, basestring): + ar = np.frombuffer(ar, dtype=np.float64) + elif not isinstance(ar, np.ndarray): + ar = np.array(ar, dtype=np.float64) + if ar.dtype != np.float64: + ar.astype(np.float64) self.array = ar def __reduce__(self): - return DenseVector, (self.array,) + return DenseVector, (self.array.tostring(),) def dot(self, other): """ @@ -207,9 +211,10 @@ def dot(self, other): ... 
AssertionError: dimension mismatch """ - if type(other) == np.ndarray and other.ndim > 1: - assert len(self) == other.shape[0], "dimension mismatch" - return np.dot(self.toArray(), other) + if type(other) == np.ndarray: + if other.ndim > 1: + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.array, other) elif _have_scipy and scipy.sparse.issparse(other): assert len(self) == other.shape[0], "dimension mismatch" return other.transpose().dot(self.toArray()) @@ -261,7 +266,7 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): - return np.array(self.array) + return self.array def __getitem__(self, item): return self.array[item] @@ -276,7 +281,7 @@ def __repr__(self): return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) def __eq__(self, other): - return isinstance(other, DenseVector) and self.array == other.array + return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) def __ne__(self, other): return not self == other @@ -314,18 +319,28 @@ def __init__(self, size, *args): if type(pairs) == dict: pairs = pairs.items() pairs = sorted(pairs) - self.indices = array.array('i', [p[0] for p in pairs]) - self.values = array.array('d', [p[1] for p in pairs]) + self.indices = np.array([p[0] for p in pairs], dtype=np.int32) + self.values = np.array([p[1] for p in pairs], dtype=np.float64) else: - assert len(args[0]) == len(args[1]), "index and value arrays not same length" - self.indices = array.array('i', args[0]) - self.values = array.array('d', args[1]) + if isinstance(args[0], basestring): + assert isinstance(args[1], str), "values should be string too" + if args[0]: + self.indices = np.frombuffer(args[0], np.int32) + self.values = np.frombuffer(args[1], np.float64) + else: + # np.frombuffer() doesn't work well with empty string in older version + self.indices = np.array([], dtype=np.int32) + self.values = np.array([], dtype=np.float64) + else: + self.indices = np.array(args[0], dtype=np.int32) + self.values = np.array(args[1], dtype=np.float64) + assert len(self.indices) == len(self.values), "index and value arrays not same length" for i in xrange(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: raise TypeError("indices array must be sorted") def __reduce__(self): - return (SparseVector, (self.size, self.indices, self.values)) + return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring())) def dot(self, other): """ @@ -461,8 +476,7 @@ def toArray(self): Returns a copy of this SparseVector as a 1-dimensional NumPy array. 
""" arr = np.zeros((self.size,), dtype=np.float64) - for i in xrange(len(self.indices)): - arr[self.indices[i]] = self.values[i] + arr[self.indices] = self.values return arr def __len__(self): @@ -493,8 +507,8 @@ def __eq__(self, other): """ return (isinstance(other, self.__class__) and other.size == self.size - and other.indices == self.indices - and other.values == self.values) + and np.array_equal(other.indices, self.indices) + and np.array_equal(other.values, self.values)) def __ne__(self, other): return not self.__eq__(other) @@ -577,25 +591,34 @@ class DenseMatrix(Matrix): """ def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) + if isinstance(values, basestring): + values = np.frombuffer(values, dtype=np.float64) + elif not isinstance(values, np.ndarray): + values = np.array(values, dtype=np.float64) assert len(values) == numRows * numCols - if not isinstance(values, array.array): - values = array.array('d', values) + if values.dtype != np.float64: + values.astype(np.float64) self.values = values def __reduce__(self): - return DenseMatrix, (self.numRows, self.numCols, self.values) + return DenseMatrix, (self.numRows, self.numCols, self.values.tostring()) def toArray(self): """ Return an numpy.ndarray - >>> arr = array.array('d', [float(i) for i in range(4)]) - >>> m = DenseMatrix(2, 2, arr) + >>> m = DenseMatrix(2, 2, range(4)) >>> m.toArray() array([[ 0., 2.], [ 1., 3.]]) """ - return np.reshape(self.values, (self.numRows, self.numCols), order='F') + return self.values.reshape((self.numRows, self.numCols), order='F') + + def __eq__(self, other): + return (isinstance(other, DenseMatrix) and + self.numRows == other.numRows and + self.numCols == other.numCols and + all(self.values == other.values)) class Matrices(object): @@ -614,8 +637,4 @@ def _test(): exit(-1) if __name__ == "__main__": - # remove current path from list of search paths to avoid importing mllib.random - # for C{import random}, which is done in an external dependency of pyspark during doctests. - import sys - sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/rand.py similarity index 69% rename from python/pyspark/mllib/random.py rename to python/pyspark/mllib/rand.py index 7eebfc6bcd894..cb4304f92152b 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/rand.py @@ -52,6 +52,12 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 @@ -76,6 +82,12 @@ def normalRDD(sc, size, numPartitions=None, seed=None): C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() @@ -93,6 +105,13 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): Generates an RDD comprised of i.i.d. 
samples from the Poisson distribution with the input mean. + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). + >>> mean = 100.0 >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) >>> stats = x.stats() @@ -104,7 +123,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - sqrt(mean)) < 0.5 True """ - return callMLlibFunc("poissonRDD", sc._jsc, mean, size, numPartitions, seed) + return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod @toArray @@ -113,6 +132,13 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): Generates an RDD comprised of vectors containing i.i.d. samples drawn from the uniform distribution U(0.0, 1.0). + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD. + :param seed: Seed for the RNG that generates the seed for the generator in each partition. + :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. + >>> import numpy as np >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape @@ -131,6 +157,13 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): Generates an RDD comprised of vectors containing i.i.d. samples drawn from the standard normal distribution. + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + >>> import numpy as np >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) >>> mat.shape @@ -149,6 +182,14 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Poisson distribution with the input mean. + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). 
+ >>> import numpy as np >>> mean = 100.0 >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) @@ -161,7 +202,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - sqrt(mean)) < 0.5 True """ - return callMLlibFunc("poissonVectorRDD", sc._jsc, mean, numRows, numCols, + return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, numPartitions, seed) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index e26b152e0cdfd..97ec74eda0b71 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,24 +15,28 @@ # limitations under the License. # +from collections import namedtuple + from pyspark import SparkContext from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc -__all__ = ['MatrixFactorizationModel', 'ALS'] +__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] -class Rating(object): - def __init__(self, user, product, rating): - self.user = int(user) - self.product = int(product) - self.rating = float(rating) +class Rating(namedtuple("Rating", ["user", "product", "rating"])): + """ + Represents a (user, product, rating) tuple. - def __reduce__(self): - return Rating, (self.user, self.product, self.rating) + >>> r = Rating(1, 2, 5.0) + >>> (r.user, r.product, r.rating) + (1, 2, 5.0) + >>> (r[0], r[1], r[2]) + (1, 2, 5.0) + """ - def __repr__(self): - return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating) + def __reduce__(self): + return Rating, (int(self.user), int(self.product), float(self.rating)) class MatrixFactorizationModel(JavaModelWrapper): @@ -51,7 +55,7 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 1, seed=10) >>> model.predictAll(testset).collect() - [Rating(1, 1, 1), Rating(1, 2, 1)] + [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)] >>> model = ALS.train(ratings, 4, seed=10) >>> model.userFeatures().collect() @@ -79,7 +83,7 @@ class MatrixFactorizationModel(JavaModelWrapper): 0.4473... 
""" def predict(self, user, product): - return self._java_model.predict(user, product) + return self._java_model.predict(int(user), int(product)) def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" @@ -106,7 +110,7 @@ def _prepare(cls, ratings): ratings = ratings.map(lambda x: Rating(*x)) else: raise ValueError("rating should be RDD of Rating or tuple/list") - return _to_java_object_rdd(ratings, True) + return ratings @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 43c1a2fc101dd..210060140fd91 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,7 +18,7 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', @@ -36,7 +36,7 @@ class LabeledPoint(object): """ def __init__(self, label, features): - self.label = label + self.label = float(label) self.features = _convert_to_vector(features) def __reduce__(self): @@ -46,7 +46,7 @@ def __str__(self): return "(" + ",".join((str(self.label), str(self.features))) + ")" def __repr__(self): - return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")" + return "LabeledPoint(%s, %s)" % (self.label, self.features) class LinearModel(object): @@ -55,7 +55,7 @@ class LinearModel(object): def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) - self._intercept = intercept + self._intercept = float(intercept) @property def weights(self): @@ -66,7 +66,7 @@ def intercept(self): return self._intercept def __repr__(self): - return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept) + return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) class LinearRegressionModelBase(LinearModel): @@ -85,6 +85,7 @@ def predict(self, x): Predict the value of the dependent variable given a vector x containing values for the independent variables. """ + x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -124,9 +125,11 @@ class LinearRegressionModel(LinearRegressionModelBase): # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) initial_weights = initial_weights or [0.0] * len(data.first().features) - weights, intercept = train_func(_to_java_object_rdd(data, cache=True), - _convert_to_vector(initial_weights)) + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) return modelClass(weights, intercept) @@ -134,7 +137,7 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False): """ Train a linear regression model on the given data. 
@@ -145,16 +148,16 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.0). :param regType: The type of regularizer used for training our model. :Allowed values: - - "l1" for using L1Updater, - - "l2" for using SquaredL2Updater, - - "none" for no regularizer. + - "l1" for using L1 regularization (lasso), + - "l2" for using L2 regularization (ridge), + - None for no regularization - (default: "none") + (default: None) @param intercept: Boolean parameter which indicates the use or not of the augmented representation for @@ -162,11 +165,11 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, are activated or not). """ def train(rdd, i): - return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step, - miniBatchFraction, i, regParam, regType, intercept) + return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), + regType, bool(intercept)) - return _regression_train_wrapper(train, LinearRegressionModel, - data, initialWeights) + return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) class LassoModel(LinearRegressionModelBase): @@ -205,12 +208,13 @@ class LassoModel(LinearRegressionModelBase): class LassoWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a Lasso regression model on the given data.""" def train(rdd, i): - return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i) + return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) + return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -250,21 +254,21 @@ class RidgeRegressionModel(LinearRegressionModelBase): class RidgeRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a ridge regression model on the given data.""" def train(rdd, i): - return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i) + return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) - return _regression_train_wrapper(train, RidgeRegressionModel, - data, initialWeights) + return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) def _test(): import doctest from pyspark import SparkContext - globs = globals().copy() + import pyspark.mllib.regression + globs = pyspark.mllib.regression.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 0700f8a8e5a8e..1980f5b03f430 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,6 +22,7 @@ from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, 
JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint __all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] @@ -107,6 +108,11 @@ def colStats(rdd): """ Computes column-wise summary statistics for the input RDD[Vector]. + :param rdd: an RDD[Vector] for which column-wise summary statistics + are to be computed. + :return: :class:`MultivariateStatisticalSummary` object containing + column-wise summary statistics. + >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), ... Vectors.dense([4, 5, 0, 3]), @@ -140,6 +146,13 @@ def corr(x, y=None, method=None): to specify the method to be used for single RDD inout. If two RDDs of floats are passed in, a single float is returned. + :param x: an RDD of vector for which the correlation matrix is to be computed, + or an RDD of float of the same cardinality as y when y is specified. + :param y: an RDD of float of the same cardinality as x. + :param method: String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman` + :return: Correlation matrix comparing columns in x. + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) @@ -242,7 +255,6 @@ def chiSqTest(observed, expected=None): >>> print round(chi.statistic, 4) 21.9958 - >>> from pyspark.mllib.regression import LabeledPoint >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), @@ -257,6 +269,8 @@ def chiSqTest(observed, expected=None): 1.5 """ if isinstance(observed, RDD): + if not isinstance(observed.first(), LabeledPoint): + raise ValueError("observed should be an RDD of LabeledPoint") jmodels = callMLlibFunc("chiSqTest", observed) return [ChiSqTestResult(m) for m in jmodels] diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 9fa4d6f6a2f5f..8332f8e061f48 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,7 +33,8 @@ else: import unittest -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ + DenseMatrix from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics @@ -62,6 +63,7 @@ def _squared_distance(a, b): class VectorTests(PySparkTestCase): def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) self.assertEqual(v, nv) @@ -75,6 +77,8 @@ def test_serialize(self): self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) self._test_serialize(DenseVector(pyarray.array('d', range(10)))) self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2}) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 5d1a3c0962796..46e253991aa56 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -15,12 +15,16 @@ # limitations under the License. 
# +from __future__ import absolute_import + +import random + from pyspark import SparkContext, RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint -__all__ = ['DecisionTreeModel', 'DecisionTree'] +__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest'] class DecisionTreeModel(JavaModelWrapper): @@ -51,27 +55,25 @@ def depth(self): return self._java_model.depth() def __repr__(self): - """ Print summary of model. """ + """ Summary of model. """ return self._java_model.toString() def toDebugString(self): - """ Print full model. """ + """ Full model. """ return self._java_model.toDebugString() class DecisionTree(object): """ - Learning algorithm for a decision tree model - for classification or regression. + Learning algorithm for a decision tree model for classification or regression. EXPERIMENTAL: This is an experimental API. - It will probably be modified for Spark v1.2. - + It will probably be modified in the future. """ - @staticmethod - def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32, + @classmethod + def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" @@ -79,8 +81,8 @@ def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBin impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) return DecisionTreeModel(model) - @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo, + @classmethod + def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ @@ -98,8 +100,8 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child nodes to create - the parent split + :param minInstancesPerNode: Min number of instances required at child + nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel @@ -124,16 +126,19 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 - True - >>> model.predict(array([0.0])) == 0 - True + >>> model.predict(array([1.0])) + 1.0 + >>> model.predict(array([0.0])) + 0.0 + >>> rdd = sc.parallelize([[1.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ - return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + return cls._train(data, "classification", numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - @staticmethod - def trainRegressor(data, categoricalFeaturesInfo, + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ @@ -150,14 +155,13 @@ def trainRegressor(data, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child nodes to create - the parent split + :param minInstancesPerNode: Min number of instances required at child + nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel Example usage: - >>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> from pyspark.mllib.linalg import SparseVector @@ -170,17 +174,213 @@ def trainRegressor(data, categoricalFeaturesInfo, ... ] >>> >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {}) - >>> model.predict(array([0.0, 1.0])) == 1 - True - >>> model.predict(array([0.0, 0.0])) == 0 - True - >>> model.predict(SparseVector(2, {1: 1.0})) == 1 - True - >>> model.predict(SparseVector(2, {1: 0.0})) == 0 - True - """ - return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {1: 0.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "regression", 0, categoricalFeaturesInfo, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + + +class RandomForestModel(JavaModelWrapper): + """ + Represents a random forest model. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified in the future. + """ + def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. + """ + if isinstance(x, RDD): + return self.call("predict", x.map(_convert_to_vector)) + + else: + return self.call("predict", _convert_to_vector(x)) + + def numTrees(self): + """ + Get number of trees in forest. + """ + return self.call("numTrees") + + def totalNumNodes(self): + """ + Get total number of nodes, summed over all trees in the forest. + """ + return self.call("totalNumNodes") + + def __repr__(self): + """ Summary of model """ + return self._java_model.toString() + + def toDebugString(self): + """ Full model """ + return self._java_model.toDebugString() + + +class RandomForest(object): + """ + Learning algorithm for a random forest model for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified in the future.
+ """ + + supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") + + @classmethod + def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, + featureSubsetStrategy, impurity, maxDepth, maxBins, seed): + first = data.first() + assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" + if featureSubsetStrategy not in cls.supportedFeatureSubsetStrategies: + raise ValueError("unsupported featureSubsetStrategy: %s" % featureSubsetStrategy) + if seed is None: + seed = random.randint(0, 1 << 30) + model = callMLlibFunc("trainRandomForestModel", data, algo, numClasses, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, + maxDepth, maxBins, seed) + return RandomForestModel(model) + + @classmethod + def trainClassifier(cls, data, numClassesForClassification, categoricalFeaturesInfo, numTrees, + featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, + seed=None): + """ + Method to train a random forest model for binary or multiclass + classification. + + :param data: Training dataset: RDD of LabeledPoint. Labels should take + values {0, 1, ..., numClasses-1}. + :param numClassesForClassification: number of classes for classification. + :param categoricalFeaturesInfo: Map storing arity of categorical features. + E.g., an entry (n -> k) indicates that feature n is categorical + with k categories indexed from 0: {0, 1, ..., k-1}. + :param numTrees: Number of trees in the random forest. + :param featureSubsetStrategy: Number of features to consider for splits at + each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + :param impurity: Criterion used for information gain calculation. + Supported values: "gini" (recommended) or "entropy". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; + depth 1 means 1 internal node + 2 leaf nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting features + (default: 32) + :param seed: Random seed for bootstrapping and choosing feature subsets. + :return: RandomForestModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import RandomForest + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(0.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ...
] + >>> model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42) + >>> model.numTrees() + 3 + >>> model.totalNumNodes() + 7 + >>> print model, + TreeEnsembleModel classifier with 3 trees + >>> print model.toDebugString(), + TreeEnsembleModel classifier with 3 trees + + Tree 0: + Predict: 1.0 + Tree 1: + If (feature 0 <= 1.0) + Predict: 0.0 + Else (feature 0 > 1.0) + Predict: 1.0 + Tree 2: + If (feature 0 <= 1.0) + Predict: 0.0 + Else (feature 0 > 1.0) + Predict: 1.0 + >>> model.predict([2.0]) + 1.0 + >>> model.predict([0.0]) + 0.0 + >>> rdd = sc.parallelize([[3.0], [1.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "classification", numClassesForClassification, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, + maxDepth, maxBins, seed) + + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", + impurity="variance", maxDepth=4, maxBins=32, seed=None): + """ + Method to train a random forest model for regression. + + :param data: Training dataset: RDD of LabeledPoint. Labels are + real numbers. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param numTrees: Number of trees in the random forest. + :param featureSubsetStrategy: Number of features to consider for + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + :param impurity: Criterion used for information gain calculation. + Supported values: "variance". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 + leaf node; depth 1 means 1 internal node + 2 leaf nodes. + (default: 4) + :param maxBins: maximum number of bins used for splitting features + (default: 32) + :param seed: Random seed for bootstrapping and choosing feature subsets. + :return: RandomForestModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import RandomForest + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ...
] + >>> + >>> model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42) + >>> model.numTrees() + 2 + >>> model.totalNumNodes() + 4 + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {0: 1.0})) + 0.5 + >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.5] + """ + return cls._train(data, "regression", 0, categoricalFeaturesInfo, numTrees, + featureSubsetStrategy, impurity, maxDepth, maxBins, seed) def _test(): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 96aef8f510fa6..4ed978b45409c 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -161,15 +161,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) - >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect() - >>> type(loaded[0]) == LabeledPoint - True - >>> print examples[0] - (1.1,(3,[0,2],[-1.23,4.56e-07])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (0.0,[1.01,2.02,3.03]) + >>> MLUtils.loadLabeledPoints(sc, tempFile.name).collect() + [LabeledPoint(1.1, (3,[0,2],[-1.23,4.56e-07])), LabeledPoint(0.0, [1.01,2.02,3.03])] """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 08d047402625f..57754776faaa2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -28,7 +28,7 @@ import warnings import heapq import bisect -from random import Random +import random from math import sqrt, log, isinf, isnan from pyspark.accumulators import PStatsParam @@ -38,7 +38,7 @@ from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter -from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler +from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ @@ -310,17 +310,43 @@ def distinct(self, numPartitions=None): def sample(self, withReplacement, fraction, seed=None): """ - Return a sampled subset of this RDD (relies on numpy and falls back - on default random generator if numpy is unavailable). + Return a sampled subset of this RDD. + + >>> rdd = sc.parallelize(range(100), 4) + >>> rdd.sample(False, 0.1, 81).count() + 10 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) + def randomSplit(self, weights, seed=None): + """ + Randomly splits this RDD with the provided weights. 
+ + :param weights: weights for splits, will be normalized if they don't sum to 1 + :param seed: random seed + :return: split RDDs in a list + + >>> rdd = sc.parallelize(range(5), 1) + >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17) + >>> rdd1.collect() + [1, 3] + >>> rdd2.collect() + [0, 2, 4] + """ + s = float(sum(weights)) + cweights = [0.0] + for w in weights: + cweights.append(cweights[-1] + w / s) + if seed is None: + seed = random.randint(0, 2 ** 32 - 1) + return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True) + for lb, ub in zip(cweights, cweights[1:])] + # this is ported from scala/spark/RDD.scala def takeSample(self, withReplacement, num, seed=None): """ - Return a fixed-size sampled subset of this RDD (currently requires - numpy). + Return a fixed-size sampled subset of this RDD. >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) @@ -341,7 +367,7 @@ def takeSample(self, withReplacement, num, seed=None): if initialCount == 0: return [] - rand = Random(seed) + rand = random.Random(seed) if (not withReplacement) and num >= initialCount: # shuffle current RDD and return diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index f5c3cfd259a5b..459e1427803cb 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -17,81 +17,48 @@ import sys import random +import math class RDDSamplerBase(object): def __init__(self, withReplacement, seed=None): - try: - import numpy - self._use_numpy = True - except ImportError: - print >> sys.stderr, ( - "NumPy does not appear to be installed. " - "Falling back to default random generator for sampling.") - self._use_numpy = False - - self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1) + self._seed = seed if seed is not None else random.randint(0, sys.maxint) self._withReplacement = withReplacement self._random = None - self._split = None - self._rand_initialized = False def initRandomGenerator(self, split): - if self._use_numpy: - import numpy - self._random = numpy.random.RandomState(self._seed ^ split) - else: - self._random = random.Random(self._seed ^ split) + self._random = random.Random(self._seed ^ split) # mixing because the initial seeds are close to each other for _ in xrange(10): self._random.randint(0, 1) - self._split = split - self._rand_initialized = True - - def getUniformSample(self, split): - if not self._rand_initialized or split != self._split: - self.initRandomGenerator(split) - - if self._use_numpy: - return self._random.random_sample() + def getUniformSample(self): + return self._random.random() + + def getPoissonSample(self, mean): + # Using Knuth's algorithm described in + # http://en.wikipedia.org/wiki/Poisson_distribution + if mean < 20.0: + # one exp and k+1 random calls + l = math.exp(-mean) + p = self._random.random() + k = 0 + while p > l: + k += 1 + p *= self._random.random() else: - return self._random.uniform(0.0, 1.0) - - def getPoissonSample(self, split, mean): - if not self._rand_initialized or split != self._split: - self.initRandomGenerator(split) - - if self._use_numpy: - return self._random.poisson(mean) - else: - # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by - # drawing a sequence of numbers delta_j ~ Exp(mean) - num_arrivals = 1 - cur_time = 0.0 - - cur_time += self._random.expovariate(mean) - - if cur_time > 1.0: - return 0 - - while(cur_time <= 1.0): - cur_time += self._random.expovariate(mean) - num_arrivals += 1 + # switch to the log domain, k+1 expovariate (random + log) 
calls + p = self._random.expovariate(mean) + k = 0 + while p < 1.0: + k += 1 + p += self._random.expovariate(mean) + return k - return (num_arrivals - 1) - - def shuffle(self, vals): - if self._random is None: - self.initRandomGenerator(0) # this should only ever called on the master so - # the split does not matter - - if self._use_numpy: - self._random.shuffle(vals) - else: - self._random.shuffle(vals, self._random.random) + def func(self, split, iterator): + raise NotImplementedError class RDDSampler(RDDSamplerBase): @@ -101,20 +68,35 @@ def __init__(self, withReplacement, fraction, seed=None): self._fraction = fraction def func(self, split, iterator): + self.initRandomGenerator(split) if self._withReplacement: for obj in iterator: # For large datasets, the expected number of occurrences of each element in # a sample with replacement is Poisson(frac). We use that to get a count for # each element. - count = self.getPoissonSample(split, mean=self._fraction) + count = self.getPoissonSample(self._fraction) for _ in range(0, count): yield obj else: for obj in iterator: - if self.getUniformSample(split) <= self._fraction: + if self.getUniformSample() < self._fraction: yield obj +class RDDRangeSampler(RDDSamplerBase): + + def __init__(self, lowerBound, upperBound, seed=None): + RDDSamplerBase.__init__(self, False, seed) + self._lowerBound = lowerBound + self._upperBound = upperBound + + def func(self, split, iterator): + self.initRandomGenerator(split) + for obj in iterator: + if self._lowerBound <= self.getUniformSample() < self._upperBound: + yield obj + + class RDDStratifiedSampler(RDDSamplerBase): def __init__(self, withReplacement, fractions, seed=None): @@ -122,15 +104,16 @@ def __init__(self, withReplacement, fractions, seed=None): self._fractions = fractions def func(self, split, iterator): + self.initRandomGenerator(split) if self._withReplacement: for key, val in iterator: # For large datasets, the expected number of occurrences of each element in # a sample with replacement is Poisson(frac). We use that to get a count for # each element. 
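The per-element count mentioned in the comment above comes from getPoissonSample earlier in this file, which draws from Poisson(frac) via Knuth's method. A minimal standalone sketch of the small-mean branch, for reference only (the helper name poisson_knuth is illustrative and not part of this patch):

    import math
    import random

    def poisson_knuth(mean, rng=random):
        # Multiply uniform draws until the running product falls below exp(-mean);
        # the number of multiplications k is Poisson(mean) distributed.
        # One exp() call and k+1 random() calls per draw, as noted in the patch.
        l = math.exp(-mean)
        p = rng.random()
        k = 0
        while p > l:
            k += 1
            p *= rng.random()
        return k

    # Averaging many draws of poisson_knuth(0.3) approaches 0.3, the expected
    # per-element occurrence count when sampling with replacement at fraction 0.3.
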
- count = self.getPoissonSample(split, mean=self._fractions[key]) + count = self.getPoissonSample(self._fractions[key]) for _ in range(0, count): yield key, val else: for key, val in iterator: - if self.getUniformSample(split) <= self._fractions[key]: + if self.getUniformSample() < self._fractions[key]: yield key, val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d597cbf94e1b1..33aa55f7f1429 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -133,6 +133,8 @@ def load_stream(self, stream): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) + if len(serialized) > (1 << 31): + raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) if self._only_write_strings: stream.write(str(serialized)) @@ -450,9 +452,9 @@ class CompressedSerializer(FramedSerializer): """ Compress the serialized data """ - def __init__(self, serializer): FramedSerializer.__init__(self) + assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer" self.serializer = serializer def dumps(self, obj): @@ -517,3 +519,8 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) + + +if __name__ == '__main__': + import doctest + doctest.testmod() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 5931e923c2e36..10a7ccd502000 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -478,13 +478,21 @@ def _get_path(self, n): os.makedirs(d) return os.path.join(d, str(n)) + def _next_limit(self): + """ + Return the next memory limit. If the memory is not released + after spilling, it will dump the data only when the used memory + starts to increase. + """ + return max(self.memory_limit, get_used_memory() * 1.05) + def sorted(self, iterator, key=None, reverse=False): """ Sort the elements in iterator, do external sort when the memory goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch = 100 + batch, limit = 100, self._next_limit() chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -504,6 +512,7 @@ def sorted(self, iterator, key=None, reverse=False): chunks.append(self.serializer.load_stream(open(path))) current_chunk = [] gc.collect() + limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e5d62a466cab6..ae288471b0e51 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -45,7 +45,7 @@ from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer + CloudPickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -1178,7 +1178,7 @@ class Row(tuple): def asDict(self): """ Return as a dict """ - return dict(zip(self.__FIELDS__, self)) + return dict((n, getattr(self, n)) for n in self.__FIELDS__) def __repr__(self): # call collect __repr__ for nested objects @@ -1870,6 +1870,21 @@ def limit(self, num): rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD() return SchemaRDD(rdd, self.sql_ctx) + def toJSON(self, use_unicode=False): + """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. 
+ + >>> srdd1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") + >>> srdd2 = sqlCtx.sql( "SELECT * from table1") + >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + True + >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + True + """ + rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 2f53fbd27b17a..d48f3598e33b2 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -142,8 +142,8 @@ def getOrCreate(cls, checkpointPath, setupFunc): recreated from the checkpoint data. If the data does not exist, then the provided setupFunc will be used to create a JavaStreamingContext. - @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program - @param setupFunc Function to create a new JavaStreamingContext and setup DStreams + @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams """ # TODO: support checkpoint in HDFS if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 491e445a216bf..32645778c2b8f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -32,6 +32,7 @@ import zipfile import random import threading +import hashlib if sys.version_info[:2] <= (2, 6): try: @@ -47,7 +48,7 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer + CloudPickleSerializer, CompressedSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType @@ -236,6 +237,17 @@ def foo(): self.assertTrue("exit" in foo.func_code.co_names) ser.dumps(foo) + def test_compressed_serializer(self): + ser = CompressedSerializer(PickleSerializer()) + from StringIO import StringIO + io = StringIO() + ser.dump_stream(["abc", u"123", range(5)], io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) + ser.dump_stream(range(1000), io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io))) + class PySparkTestCase(unittest.TestCase): @@ -440,7 +452,7 @@ def test_sampling_default_seed(self): subset = data.takeSample(False, 10) self.assertEqual(len(subset), 10) - def testAggregateByKey(self): + def test_aggregate_by_key(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) def seqOp(x, y): @@ -478,6 +490,32 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEquals(N, m) + def test_multiple_broadcasts(self): + N = 1 << 21 + b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM + r = range(1 << 15) + random.shuffle(r) + s = str(r) + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: 
(len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + random.shuffle(r) + s = str(r) + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + def test_large_closure(self): N = 1000000 data = [float(i) for i in xrange(N)] @@ -755,7 +793,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name) + shutil.rmtree(cls.tempdir.name, ignore_errors=True) def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -882,8 +920,9 @@ def test_convert_row_to_dict(self): rdd = self.sc.parallelize([row]) srdd = self.sqlCtx.inferSchema(rdd) srdd.registerTempTable("test") - row = self.sqlCtx.sql("select l[0].a AS la from test").first() - self.assertEqual(1, row.asDict()["la"]) + row = self.sqlCtx.sql("select l, d from test").first() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) def test_infer_schema_with_udt(self): from pyspark.tests import ExamplePoint, ExamplePointUDT diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2bdccb5e93f09..7e5343c973dc5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,8 +30,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - CompressedSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer from pyspark import shuffle pickleSer = PickleSerializer() @@ -78,12 +77,11 @@ def main(infile, outfile): # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) - ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) if bid >= 0: - value = ser._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) + path = utf8_deserializer.loads(infile) + _broadcastRegistry[bid] = Broadcast(path=path) else: bid = - bid - 1 _broadcastRegistry.pop(bid) diff --git a/python/run-tests b/python/run-tests index a4f0cac059ff3..9ee19ed6e6b26 100755 --- a/python/run-tests +++ b/python/run-tests @@ -56,7 +56,7 @@ function run_core_tests() { run_test "pyspark/conf.py" PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py" + run_test "pyspark/serializers.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } @@ -72,7 +72,7 @@ function run_mllib_tests() { run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/feature.py" run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/random.py" + run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/stat.py" diff --git a/repl/pom.xml b/repl/pom.xml index af528c8914335..9b2290429fee5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -35,9 +35,16 @@ repl /usr/share/spark root + scala-2.10/src/main/scala + 
scala-2.10/src/test/scala + + ${jline.groupid} + jline + ${jline.version} + org.apache.spark spark-core_${scala.binary.version} @@ -75,11 +82,6 @@ scala-reflect ${scala.version} - - org.scala-lang - jline - ${scala.version} - org.slf4j jul-to-slf4j @@ -122,6 +124,51 @@ + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + ${extra.source.dir} + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + ${extra.testsource.dir} + + + + + + + + scala-2.11 + + scala-2.11 + + + scala-2.11/src/main/scala + scala-2.11/src/test/scala + + + diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/Main.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala similarity index 95% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 7667a9c11979e..da4286c5e4874 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -121,11 +121,14 @@ trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc.") + _sc + } """) command("import org.apache.spark.SparkContext._") } - echo("Spark context available as sc.") } // code to be executed only after the interpreter is initialized diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala rename to 
repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkImports.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala similarity index 100% rename from repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala rename to repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala new file mode 100644 index 0000000000000..5e93a71995072 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.repl + +import org.apache.spark.util.Utils +import org.apache.spark._ + +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.SparkILoop + +object Main extends Logging { + + val conf = new SparkConf() + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = conf.get("spark.repl.classdir", tmp) + val outputDir = Utils.createTempDir(rootDir) + val s = new Settings() + s.processArguments(List("-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) + val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + var sparkContext: SparkContext = _ + var interp = new SparkILoop // this is a public var because tests reset it. + + def main(args: Array[String]) { + if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + interp.process(s) // Repl starts and goes in loop of R.E.P.L + classServer.stop() + Option(sparkContext).map(_.stop) + } + + + def getAddedJars: Array[String] = { + val envJars = sys.env.get("ADD_JARS") + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } + val jars = propJars.orElse(envJars).getOrElse("") + Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) + } + + def createSparkContext(): SparkContext = { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + val jars = getAddedJars + val conf = new SparkConf() + .setMaster(getMaster) + .setAppName("Spark shell") + .setJars(jars) + .set("spark.repl.class.uri", classServer.uri) + logInfo("Spark class server started at " + classServer.uri) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } + sparkContext = new SparkContext(conf) + logInfo("Created spark context..") + sparkContext + } + + private def getMaster: String = { + val master = { + val envMaster = sys.env.get("MASTER") + val propMaster = sys.props.get("spark.master") + propMaster.orElse(envMaster).getOrElse("local[*]") + } + master + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 0000000000000..8e519fa67f649 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,86 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package scala.tools.nsc +package interpreter + +import scala.tools.nsc.ast.parser.Tokens.EOF + +trait SparkExprTyper { + val repl: SparkIMain + + import repl._ + import global.{ reporter => _, Import => _, _ } + import naming.freshInternalVarName + + def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. 
+ val line = "def " + name + " = " + code + + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) + case _ => NoSymbol + } + } + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + interpretSynthetic(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + def asError(): Symbol = { + interpretSynthetic(code) + NoSymbol + } + beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() + } + + private var typeOfExpressionDepth = 0 + def typeOfExpression(expr: String, silent: Boolean = true): Type = { + if (typeOfExpressionDepth > 2) { + repldbg("Terminating typeOfExpression recursion for expression: " + expr) + return NoType + } + typeOfExpressionDepth += 1 + // Don't presently have a good way to suppress undesirable success output + // while letting errors through, so it is first trying it silently: if there + // is an error, and errors are desired, then it re-evaluates non-silently + // to induce the error message. + try beSilentDuring(symbolOfLine(expr).tpe) match { + case NoType if !silent => symbolOfLine(expr).tpe // generate error + case tpe => tpe + } + finally typeOfExpressionDepth -= 1 + } + + // This only works for proper types. + def typeOfTypeString(typeString: String): Type = { + def asProperType(): Option[Type] = { + val name = freshInternalVarName() + val line = "def %s: %s = ???" format (name, typeString) + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + Some(sym0.asMethod.returnType) + case _ => None + } + } + beSilentDuring(asProperType()) getOrElse NoType + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 0000000000000..250727305970d --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,969 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Alexander Spoon + */ + +package scala +package tools.nsc +package interpreter + +import scala.language.{ implicitConversions, existentials } +import scala.annotation.tailrec +import Predef.{ println => _, _ } +import interpreter.session._ +import StdReplTags._ +import scala.reflect.api.{Mirror, Universe, TypeCreator} +import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName } +import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import scala.reflect.{ClassTag, classTag} +import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader } +import ScalaClassLoader._ +import scala.reflect.io.{ File, Directory } +import scala.tools.util._ +import scala.collection.generic.Clearable +import scala.concurrent.{ ExecutionContext, Await, Future, future } +import ExecutionContext.Implicits._ +import java.io.{ BufferedReader, FileReader } + +/** The Scala interactive shell. It provides a read-eval-print loop + * around the Interpreter class. + * After instantiation, clients should call the main() method. + * + * If no in0 is specified, then input will come from the console, and + * the class will attempt to provide input editing feature such as + * input history. + * + * @author Moez A. 
Abdel-Gawad + * @author Lex Spoon + * @version 1.2 + */ +class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) + extends AnyRef + with LoopCommands +{ + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) +// +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i + + var in: InteractiveReader = _ // the input stream from which commands come + var settings: Settings = _ + var intp: SparkIMain = _ + + var globalFuture: Future[Boolean] = _ + + protected def asyncMessage(msg: String) { + if (isReplInfo || isReplPower) + echoAndRefresh(msg) + } + + def initializeSpark() { + intp.beQuietDuring { + command( """ + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc.") + _sc + } + """) + command("import org.apache.spark.SparkContext._") + } + } + + /** Print a welcome message */ + def printWelcome() { + import org.apache.spark.SPARK_VERSION + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + override def echoCommandMessage(msg: String) { + intp.reporter printUntruncatedMessage msg + } + + // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) + def history = in.history + + // classpath entries added via :cp + var addedClasspath: String = "" + + /** A reverse list of commands to replay if the user requests a :replay */ + var replayCommandStack: List[String] = Nil + + /** A list of commands to replay if the user requests a :replay */ + def replayCommands = replayCommandStack.reverse + + /** Record a command for replay should the user request a :replay */ + def addReplay(cmd: String) = replayCommandStack ::= cmd + + def savingReplayStack[T](body: => T): T = { + val saved = replayCommandStack + try body + finally replayCommandStack = saved + } + def savingReader[T](body: => T): T = { + val saved = in + try body + finally in = saved + } + + /** Close the interpreter and set the var to null. */ + def closeInterpreter() { + if (intp ne null) { + intp.close() + intp = null + } + } + + class SparkILoopInterpreter extends SparkIMain(settings, out) { + outer => + + override lazy val formatting = new Formatting { + def prompt = SparkILoop.this.prompt + } + override protected def parentClassLoader = + settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader ) + } + + /** Create a new interpreter. */ + def createInterpreter() { + if (addedClasspath != "") + settings.classpath append addedClasspath + + intp = new SparkILoopInterpreter + } + + /** print a friendly help message */ + def helpCommand(line: String): Result = { + if (line == "") helpSummary() + else uniqueCommand(line) match { + case Some(lc) => echo("\n" + lc.help) + case _ => ambiguousError(line) + } + } + private def helpSummary() = { + val usageWidth = commands map (_.usageMsg.length) max + val formatStr = "%-" + usageWidth + "s %s" + + echo("All commands can be abbreviated, e.g. 
:he instead of :help.") + + commands foreach { cmd => + echo(formatStr.format(cmd.usageMsg, cmd.help)) + } + } + private def ambiguousError(cmd: String): Result = { + matchingCommands(cmd) match { + case Nil => echo(cmd + ": no such command. Type :help for help.") + case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") + } + Result(keepRunning = true, None) + } + private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) + private def uniqueCommand(cmd: String): Option[LoopCommand] = { + // this lets us add commands willy-nilly and only requires enough command to disambiguate + matchingCommands(cmd) match { + case List(x) => Some(x) + // exact match OK even if otherwise appears ambiguous + case xs => xs find (_.name == cmd) + } + } + + /** Show the history */ + lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + override def usage = "[num]" + def defaultLines = 20 + + def apply(line: String): Result = { + if (history eq NoHistory) + return "No history available." + + val xs = words(line) + val current = history.index + val count = try xs.head.toInt catch { case _: Exception => defaultLines } + val lines = history.asStrings takeRight count + val offset = current - lines.size + 1 + + for ((line, index) <- lines.zipWithIndex) + echo("%3d %s".format(index + offset, line)) + } + } + + // When you know you are most likely breaking into the middle + // of a line being typed. This softens the blow. + protected def echoAndRefresh(msg: String) = { + echo("\n" + msg) + in.redrawLine() + } + protected def echo(msg: String) = { + out println msg + out.flush() + } + + /** Search the history */ + def searchHistory(_cmdline: String) { + val cmdline = _cmdline.toLowerCase + val offset = history.index - history.size + 1 + + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) + echo("%d %s".format(index + offset, line)) + } + + private val currentPrompt = Properties.shellPromptString + + /** Prompt to print when awaiting input */ + def prompt = currentPrompt + + import LoopCommand.{ cmd, nullary } + + /** Standard commands **/ + lazy val standardCommands = List( + cmd("cp", "", "add a jar or directory to the classpath", addClasspath), + cmd("edit", "|", "edit history", editCommand), + cmd("help", "[command]", "print this summary or command-specific help", helpCommand), + historyCommand, + cmd("h?", "", "search the history", searchHistory), + cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), + //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand), + cmd("javap", "", "disassemble a file or class name", javapCommand), + cmd("line", "|", "place line(s) at the end of history", lineCommand), + cmd("load", "", "interpret lines in a file", loadCommand), + cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand), + // nullary("power", "enable power user mode", powerCmd), + nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)), + nullary("replay", "reset execution and replay all previous commands", replay), + nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), + cmd("save", "", "save replayable session to a file", saveCommand), + shCommand, + cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings), + nullary("silent", "disable/enable 
automatic printing of results", verbosity), +// cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), +// cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand), + nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) + ) + + /** Power user commands */ +// lazy val powerCommands: List[LoopCommand] = List( +// cmd("phase", "", "set the implicit phase for power commands", phaseCommand) +// ) + + private def importsCommand(line: String): Result = { + val tokens = words(line) + val handlers = intp.languageWildcardHandlers ++ intp.importHandlers + + handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { + case (handler, idx) => + val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) + val imps = handler.implicitSymbols + val found = tokens filter (handler importsSymbolNamed _) + val typeMsg = if (types.isEmpty) "" else types.size + " types" + val termMsg = if (terms.isEmpty) "" else terms.size + " terms" + val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" + val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") + val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") + + intp.reporter.printMessage("%2d) %-30s %s%s".format( + idx + 1, + handler.importString, + statsMsg, + foundMsg + )) + } + } + + private def findToolsJar() = PathResolver.SupplementalLocations.platformTools + + private def addToolsJarToLoader() = { + val cl = findToolsJar() match { + case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) + case _ => intp.classLoader + } + if (Javap.isAvailable(cl)) { + repldbg(":javap available.") + cl + } + else { + repldbg(":javap unavailable: no tools.jar at " + jdkHome) + intp.classLoader + } + } +// +// protected def newJavap() = +// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp)) +// +// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) + + // Still todo: modules. +// private def typeCommand(line0: String): Result = { +// line0.trim match { +// case "" => ":type [-v] " +// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + +// private def kindCommand(expr: String): Result = { +// expr.trim match { +// case "" => ":kind [-v] " +// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + + private def warningsCommand(): Result = { + if (intp.lastWarnings.isEmpty) + "Can't find any cached warnings." 
+ else + intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } + } + + private def changeSettings(args: String): Result = { + def showSettings() = { + for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString) + } + def updateSettings() = { + // put aside +flag options + val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+")) + val tmps = new Settings + val (ok, leftover) = tmps.processArguments(rest, processAll = true) + if (!ok) echo("Bad settings request.") + else if (leftover.nonEmpty) echo("Unprocessed settings.") + else { + // boolean flags set-by-user on tmp copy should be off, not on + val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting]) + val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg)) + // update non-flags + settings.processArguments(nonbools, processAll = true) + // also snag multi-value options for clearing, e.g. -Ylog: and -language: + for { + s <- settings.userSetSettings + if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting] + if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init)) + } s match { + case c: Clearable => c.clear() + case _ => + } + def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = { + for (b <- bs) + settings.lookupSetting(name(b)) match { + case Some(s) => + if (s.isInstanceOf[Settings#BooleanSetting]) setter(s) + else echo(s"Not a boolean flag: $b") + case _ => + echo(s"Not an option: $b") + } + } + update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off + update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on + } + } + if (args.isEmpty) showSettings() else updateSettings() + } + + private def javapCommand(line: String): Result = { +// if (javap == null) +// ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) +// else if (line == "") +// ":javap [-lcsvp] [path1 path2 ...]" +// else +// javap(words(line)) foreach { res => +// if (res.isError) return "Failed: " + res.value +// else res.show() +// } + } + + private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent" + + private def phaseCommand(name: String): Result = { +// val phased: Phased = power.phased +// import phased.NoPhaseName +// +// if (name == "clear") { +// phased.set(NoPhaseName) +// intp.clearExecutionWrapper() +// "Cleared active phase." +// } +// else if (name == "") phased.get match { +// case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" +// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) +// } +// else { +// val what = phased.parse(name) +// if (what.isEmpty || !phased.set(what)) +// "'" + name + "' does not appear to represent a valid phase." +// else { +// intp.setExecutionWrapper(pathToPhaseWrapper) +// val activeMessage = +// if (what.toString.length == name.length) "" + what +// else "%s (%s)".format(what, name) +// +// "Active phase is now: " + activeMessage +// } +// } + } + + /** Available commands */ + def commands: List[LoopCommand] = standardCommands ++ ( + // if (isReplPower) + // powerCommands + // else + Nil + ) + + val replayQuestionMessage = + """|That entry seems to have slain the compiler. Shall I replay + |your session? I can re-run each line except the last one. 
+ |[y/n] + """.trim.stripMargin + + private val crashRecovery: PartialFunction[Throwable, Boolean] = { + case ex: Throwable => + val (err, explain) = ( + if (intp.isInitializeComplete) + (intp.global.throwableAsString(ex), "") + else + (ex.getMessage, "The compiler did not initialize.\n") + ) + echo(err) + + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("\nUnrecoverable error.") + throw ex + case _ => + def fn(): Boolean = + try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + true + } + + // return false if repl should exit + def processLine(line: String): Boolean = { + import scala.concurrent.duration._ + Await.ready(globalFuture, 60.seconds) + + (line ne null) && (command(line) match { + case Result(false, _) => false + case Result(_, Some(line)) => addReplay(line) ; true + case _ => true + }) + } + + private def readOneLine() = { + out.flush() + in readLine prompt + } + + /** The main read-eval-print loop for the repl. It calls + * command() for each line of input, and stops when + * command() returns false. + */ + @tailrec final def loop() { + if ( try processLine(readOneLine()) catch crashRecovery ) + loop() + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(file: File) { + savingReader { + savingReplayStack { + file applyReader { reader => + in = SimpleReader(reader, out, interactive = false) + echo("Loading " + file + "...") + loop() + } + } + } + } + + /** create a new interpreter and replay the given commands */ + def replay() { + reset() + if (replayCommandStack.isEmpty) + echo("Nothing to replay.") + else for (cmd <- replayCommands) { + echo("Replaying: " + cmd) // flush because maybe cmd will have its own output + command(cmd) + echo("") + } + } + def resetCommand() { + echo("Resetting interpreter state.") + if (replayCommandStack.nonEmpty) { + echo("Forgetting this session history:\n") + replayCommands foreach echo + echo("") + replayCommandStack = Nil + } + if (intp.namedDefinedTerms.nonEmpty) + echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) + if (intp.definedTypes.nonEmpty) + echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) + + reset() + } + def reset() { + intp.reset() + unleashAndSetPhase() + } + + def lineCommand(what: String): Result = editCommand(what, None) + + // :edit id or :edit line + def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR")) + + def editCommand(what: String, editor: Option[String]): Result = { + def diagnose(code: String) = { + echo("The edited code is incomplete!\n") + val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") + if (errless) echo("The compiler reports no errors.") + } + def historicize(text: String) = history match { + case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true + case _ => false + } + def edit(text: String): Result = editor match { + case Some(ed) => + val tmp = File.makeTemp() + tmp.writeAll(text) + try { + val pr = new ProcessResult(s"$ed ${tmp.path}") + pr.exitCode match { + case 0 => + tmp.safeSlurp() match { + case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.") + case Some(edited) => + echo(edited.lines map ("+" + _) mkString "\n") + val res = intp interpret edited + if (res == IR.Incomplete) diagnose(edited) + else { 
+ historicize(edited) + Result(lineToRecord = Some(edited), keepRunning = true) + } + case None => echo("Can't read edited text. Did you delete it?") + } + case x => echo(s"Error exit from $ed ($x), ignoring") + } + } finally { + tmp.delete() + } + case None => + if (historicize(text)) echo("Placing text in recent history.") + else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text") + } + + // if what is a number, use it as a line number or range in history + def isNum = what forall (c => c.isDigit || c == '-' || c == '+') + // except that "-" means last value + def isLast = (what == "-") + if (isLast || !isNum) { + val name = if (isLast) intp.mostRecentVar else what + val sym = intp.symbolOfIdent(name) + intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match { + case Some(req) => edit(req.line) + case None => echo(s"No symbol in scope: $what") + } + } else try { + val s = what + // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur) + val (start, len) = + if ((s indexOf '+') > 0) { + val (a,b) = s splitAt (s indexOf '+') + (a.toInt, b.drop(1).toInt) + } else { + (s indexOf '-') match { + case -1 => (s.toInt, 1) + case 0 => val n = s.drop(1).toInt ; (history.index - n, n) + case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n) + case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n) + } + } + import scala.collection.JavaConverters._ + val index = (start - 1) max 0 + val text = history match { + case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n" + case _ => history.asStrings.slice(index, index + len) mkString "\n" + } + edit(text) + } catch { + case _: NumberFormatException => echo(s"Bad range '$what'") + echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)") + } + } + + /** fork a shell and run a command */ + lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + override def usage = "" + def apply(line: String): Result = line match { + case "" => showUsage() + case _ => + val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})" + intp interpret toRun + () + } + } + + def withFile[A](filename: String)(action: File => A): Option[A] = { + val res = Some(File(filename)) filter (_.exists) map action + if (res.isEmpty) echo("That file does not exist") // courtesy side-effect + res + } + + def loadCommand(arg: String) = { + var shouldReplay: Option[String] = None + withFile(arg)(f => { + interpretAllFrom(f) + shouldReplay = Some(":load " + arg) + }) + Result(keepRunning = true, shouldReplay) + } + + def saveCommand(filename: String): Result = ( + if (filename.isEmpty) echo("File name is required.") + else if (replayCommandStack.isEmpty) echo("No replay commands in session") + else File(filename).printlnAll(replayCommands: _*) + ) + + def addClasspath(arg: String): Unit = { + val f = File(arg).normalize + if (f.exists) { + addedClasspath = ClassPath.join(addedClasspath, f.path) + val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) + replay() + } + else echo("The path '" + f + "' doesn't seem to exist.") + } + + def powerCmd(): Result = { + if (isReplPower) "Already in power mode." 
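// Illustrative sketch (not part of the patch) of the :edit history-range parsing shown above:
// "123" is one line, "120+3" is 3 lines from 120, "120-123" is an inclusive range, "120-"
// runs to the current index, and "-3" means the last 3 lines. historyIndex is a hypothetical
// stand-in for history.index.
object EditRange {
  def parse(s: String, historyIndex: Int): (Int, Int) =
    if ((s indexOf '+') > 0) {
      val (a, b) = s splitAt (s indexOf '+')
      (a.toInt, b.drop(1).toInt)
    } else (s indexOf '-') match {
      case -1                 => (s.toInt, 1)
      case 0                  => val n = s.drop(1).toInt; (historyIndex - n, n)
      case _ if s.last == '-' => val n = s.init.toInt;    (n, historyIndex - n)
      case i                  => val n = s.take(i).toInt; (n, s.drop(i + 1).toInt - n)
    }

  def main(args: Array[String]): Unit =
    List("123", "120+3", "-3", "120-123", "120-") foreach { s =>
      println(s + " => " + parse(s, historyIndex = 130))
    }
}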
+ else enablePowerMode(isDuringInit = false) + } + def enablePowerMode(isDuringInit: Boolean) = { + replProps.power setValue true + unleashAndSetPhase() + // asyncEcho(isDuringInit, power.banner) + } + private def unleashAndSetPhase() { + if (isReplPower) { + // power.unleash() + // Set the phase to "typer" + // intp beSilentDuring phaseCommand("typer") + } + } + + def asyncEcho(async: Boolean, msg: => String) { + if (async) asyncMessage(msg) + else echo(msg) + } + + def verbosity() = { + val old = intp.printResults + intp.printResults = !old + echo("Switched " + (if (old) "off" else "on") + " result printing.") + } + + /** Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. */ + def command(line: String): Result = { + if (line startsWith ":") { + val cmd = line.tail takeWhile (x => !x.isWhitespace) + uniqueCommand(cmd) match { + case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) + case _ => ambiguousError(cmd) + } + } + else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler + else Result(keepRunning = true, interpretStartingWith(line)) + } + + private def readWhile(cond: String => Boolean) = { + Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) + } + + def pasteCommand(arg: String): Result = { + var shouldReplay: Option[String] = None + def result = Result(keepRunning = true, shouldReplay) + val (raw, file) = + if (arg.isEmpty) (false, None) + else { + val r = """(-raw)?(\s+)?([^\-]\S*)?""".r + arg match { + case r(flag, sep, name) => + if (flag != null && name != null && sep == null) + echo(s"""I assume you mean "$flag $name"?""") + (flag != null, Option(name)) + case _ => + echo("usage: :paste -raw file") + return result + } + } + val code = file match { + case Some(name) => + withFile(name)(f => { + shouldReplay = Some(s":paste $arg") + val s = f.slurp.trim + if (s.isEmpty) echo(s"File contains no code: $f") + else echo(s"Pasting file $f...") + s + }) getOrElse "" + case None => + echo("// Entering paste mode (ctrl-D to finish)\n") + val text = (readWhile(_ => true) mkString "\n").trim + if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n") + else echo("\n// Exiting paste mode, now interpreting.\n") + text + } + def interpretCode() = { + val res = intp interpret code + // if input is incomplete, let the compiler try to say why + if (res == IR.Incomplete) { + echo("The pasted code is incomplete!\n") + // Remembrance of Things Pasted in an object + val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") + if (errless) echo("...but compilation found no error? Good luck with that.") + } + } + def compileCode() = { + val errless = intp compileSources new BatchSourceFile("", code) + if (!errless) echo("There were compilation errors!") + } + if (code.nonEmpty) { + if (raw) compileCode() else interpretCode() + } + result + } + + private object paste extends Pasted { + val ContinueString = " | " + val PromptString = "scala> " + + def interpret(line: String): Unit = { + echo(line.trim) + intp interpret line + echo("") + } + + def transcript(start: String) = { + echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") + apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) + } + } + import paste.{ ContinueString, PromptString } + + /** Interpret expressions starting with the first line. 
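// Illustrative sketch (not part of the patch): the :paste argument regex used above,
// exercised as a pattern-match extractor to split the optional -raw flag from a file name.
object PasteArgs {
  private val r = """(-raw)?(\s+)?([^\-]\S*)?""".r

  def parse(arg: String): Option[(Boolean, Option[String])] = arg match {
    case r(flag, _, name) => Some((flag != null, Option(name)))
    case _                => None
  }

  def main(args: Array[String]): Unit =
    List("", "-raw", "foo.scala", "-raw foo.scala") foreach { a =>
      println("'" + a + "' => " + parse(a))
    }
}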
+ * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ + def interpretStartingWith(code: String): Option[String] = { + // signal completion non-completion input has been received + in.completion.resetVerbosity() + + def reallyInterpret = { + val reallyResult = intp.interpret(code) + (reallyResult, reallyResult match { + case IR.Error => None + case IR.Success => Some(code) + case IR.Incomplete => + if (in.interactive && code.endsWith("\n\n")) { + echo("You typed two blank lines. Starting a new command.") + None + } + else in.readLine(ContinueString) match { + case null => + // we know compilation is going to fail since we're at EOF and the + // parser thinks the input is still incomplete, but since this is + // a file being read non-interactively we want to fail. So we send + // it straight to the compiler for the nice error message. + intp.compileString(code) + None + + case line => interpretStartingWith(code + "\n" + line) + } + }) + } + + /** Here we place ourselves between the user and the interpreter and examine + * the input they are ostensibly submitting. We intervene in several cases: + * + * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. + * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation + * on the previous result. + * 3) If the Completion object's execute returns Some(_), we inject that value + * and avoid the interpreter, as it's likely not valid scala code. + */ + if (code == "") None + else if (!paste.running && code.trim.startsWith(PromptString)) { + paste.transcript(code) + None + } + else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { + interpretStartingWith(intp.mostRecentVar + code) + } + else if (code.trim startsWith "//") { + // line comment, do nothing + None + } + else + reallyInterpret._2 + } + + // runs :load `file` on any files passed via -i + def loadFiles(settings: Settings) = settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + addReplay(cmd) + echo("") + } + case _ => + } + + /** Tries to create a JLineReader, falling back to SimpleReader: + * unless settings or properties are such that it should start + * with SimpleReader. + */ + def chooseReader(settings: Settings): InteractiveReader = { + if (settings.Xnojline || Properties.isEmacsShell) + SimpleReader() + else try new JLineReader( + if (settings.noCompletion) NoCompletion + else new SparkJLineCompletion(intp) + ) + catch { + case ex @ (_: Exception | _: NoClassDefFoundError) => + echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.") + SimpleReader() + } + } + protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = + u.TypeTag[T]( + m, + new TypeCreator { + def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type = + m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] + }) + + private def loopPostInit() { + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain])) + // Auto-run code via some setting. 
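// Illustrative sketch (not part of the patch) of the continuation-reading idea above: keep
// appending lines until the input forms a complete unit. The real code asks the compiler
// (IR.Incomplete); the balanced-brace check here is only a toy stand-in for that decision.
import scala.annotation.tailrec
import scala.io.StdIn

object MiniIncremental {
  def looksComplete(code: String): Boolean =
    code.count(_ == '{') <= code.count(_ == '}')

  @tailrec def readUnit(code: String): Option[String] =
    if (looksComplete(code)) Some(code)
    else StdIn.readLine("     | ") match {
      case null => None                          // EOF while the unit is still incomplete
      case line => readUnit(code + "\n" + line)
    }

  def main(args: Array[String]): Unit =
    Option(StdIn.readLine("scala> ")).flatMap(readUnit).foreach(println)
}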
+ ( replProps.replAutorunCode.option + flatMap (f => io.File(f).safeSlurp()) + foreach (intp quietRun _) + ) + // classloader and power mode setup + intp.setContextClassLoader() + if (isReplPower) { + // replProps.power setValue true + // unleashAndSetPhase() + // asyncMessage(power.banner) + } + // SI-7418 Now, and only now, can we enable TAB completion. + in match { + case x: JLineReader => x.consoleReader.postInit + case _ => + } + } + def process(settings: Settings): Boolean = savingContextLoader { + this.settings = settings + createInterpreter() + + // sets in to some kind of reader depending on environmental cues + in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) + globalFuture = future { + intp.initializeSynchronous() + loopPostInit() + !intp.reporter.hasErrors + } + import scala.concurrent.duration._ + Await.ready(globalFuture, 10 seconds) + printWelcome() + initializeSpark() + loadFiles(settings) + + try loop() + catch AbstractOrMissingHandler() + finally closeInterpreter() + + true + } + + @deprecated("Use `process` instead", "2.9.0") + def main(settings: Settings): Unit = process(settings) //used by sbt +} + +object SparkILoop { + implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp + + // Designed primarily for use by test code: take a String with a + // bunch of code, and prints out a transcript of what it would look + // like if you'd just typed it into the repl. + def runForTranscript(code: String, settings: Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { + override def write(str: String) = { + // completely skip continuation lines + if (str forall (ch => ch.isWhitespace || ch == '|')) () + else super.write(str) + } + } + val input = new BufferedReader(new StringReader(code.trim + "\n")) { + override def readLine(): String = { + val s = super.readLine() + // helping out by printing the line being interpreted. + if (s != null) + output.println(s) + s + } + } + val repl = new SparkILoop(input, output) + if (settings.classpath.isDefault) + settings.classpath.value = sys.props("java.class.path") + + repl process settings + } + } + } + + /** Creates an interpreter loop with default settings and feeds + * the given code to it as input. 
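// Possible usage sketch (not part of the patch) of the transcript helper defined above, as
// test code might call it. The import path for SparkILoop is assumed from this patch's file
// layout; usejavacp simply points the embedded compiler at the JVM classpath.
import scala.tools.nsc.Settings
import org.apache.spark.repl.SparkILoop

object TranscriptDemo {
  def main(args: Array[String]): Unit = {
    val settings = new Settings
    settings.usejavacp.value = true
    val transcript = SparkILoop.runForTranscript("val x = 1 + 1\nprintln(x)", settings)
    println(transcript)
  }
}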
+ */ + def run(code: String, sets: Settings = new Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) + + if (sets.classpath.isDefault) + sets.classpath.value = sys.props("java.class.path") + + repl process sets + } + } + } + def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala new file mode 100644 index 0000000000000..1bb62c84abddc --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -0,0 +1,1319 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Martin Odersky + */ + +package scala +package tools.nsc +package interpreter + +import PartialFunction.cond +import scala.language.implicitConversions +import scala.beans.BeanProperty +import scala.collection.mutable +import scala.concurrent.{ Future, ExecutionContext } +import scala.reflect.runtime.{ universe => ru } +import scala.reflect.{ ClassTag, classTag } +import scala.reflect.internal.util.{ BatchSourceFile, SourceFile } +import scala.tools.util.PathResolver +import scala.tools.nsc.io.AbstractFile +import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings } +import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps } +import scala.tools.nsc.util.Exceptional.unwrap +import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable} + +/** An interpreter for Scala code. + * + * The main public entry points are compile(), interpret(), and bind(). + * The compile() method loads a complete Scala file. The interpret() method + * executes one line of Scala code at the request of the user. The bind() + * method binds an object to a variable that can then be used by later + * interpreted code. + * + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports members called "$eval" and "$print". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * + * @author Moez A. 
Abdel-Gawad + * @author Lex Spoon + */ +class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings, + protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports { + imain => + + setBindings(createBindings, ScriptContext.ENGINE_SCOPE) + object replOutput extends ReplOutput(settings.Yreploutdir) { } + + @deprecated("Use replOutput.dir instead", "2.11.0") + def virtualDirectory = replOutput.dir + // Used in a test case. + def showDirectory() = replOutput.show(out) + + private[nsc] var printResults = true // whether to print result lines + private[nsc] var totalSilence = false // whether to print anything + private var _initializeComplete = false // compiler is initialized + private var _isInitialized: Future[Boolean] = null // set up initialization future + private var bindExceptions = true // whether to bind the lastException variable + private var _executionWrapper = "" // code to be wrapped around all lines + + /** We're going to go to some trouble to initialize the compiler asynchronously. + * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private var _classLoader: util.AbstractFileClassLoader = null // active classloader + private val _compiler: ReplGlobal = newCompiler(settings, reporter) // our private compiler + + def compilerClasspath: Seq[java.net.URL] = ( + if (isInitializeComplete) global.classPath.asURLs + else new PathResolver(settings).result.asURLs // the compiler's classpath + ) + def settings = initialSettings + // Run the code body with the given boolean settings flipped to true. + def withoutWarnings[T](body: => T): T = beQuietDuring { + val saved = settings.nowarn.value + if (!saved) + settings.nowarn.value = true + + try body + finally if (!saved) settings.nowarn.value = false + } + + /** construct an interpreter that reports to Console */ + def this(settings: Settings, out: JPrintWriter) = this(null, settings, out) + def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this(factory: ScriptEngineFactory) = this(factory, new Settings()) + def this() = this(new Settings()) + + lazy val formatting: Formatting = new Formatting { + val prompt = Properties.shellPromptString + } + lazy val reporter: SparkReplReporter = new SparkReplReporter(this) + + import formatting._ + import reporter.{ printMessage, printUntruncatedMessage } + + // This exists mostly because using the reporter too early leads to deadlock. 
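// Heavily simplified illustration (not part of the patch, and not the literal code the
// interpreter generates) of the wrapping scheme described in the class comment above: each
// input line is compiled into a fresh object exporting its definitions, and a second
// "result" object imports them to build the value and its printed representation.
object $line1 {
  object $read {
    val x = 1 + 1                 // the user's line, exported as a public member
  }
  object $eval {
    import $read._                // later lines see earlier results through imports like this
    lazy val $result = x
    def $print: String = "x: Int = " + x
  }
}

object WrapDemo {
  def main(args: Array[String]): Unit = println($line1.$eval.$print)
}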
+ private def echo(msg: String) { Console println msg } + private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) + private def _initialize() = { + try { + // if this crashes, REPL will hang its head in shame + val run = new _compiler.Run() + assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") + run compileSources _initSources + _initializeComplete = true + true + } + catch AbstractOrMissingHandler() + } + private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" + private val logScope = scala.sys.props contains "scala.repl.scope" + private def scopelog(msg: String) = if (logScope) Console.err.println(msg) + + // argument is a thunk to execute after init is done + def initialize(postInitSignal: => Unit) { + synchronized { + if (_isInitialized == null) { + _isInitialized = + Future(try _initialize() finally postInitSignal)(ExecutionContext.global) + } + } + } + def initializeSynchronous(): Unit = { + if (!isInitializeComplete) { + _initialize() + assert(global != null, global) + } + } + def isInitializeComplete = _initializeComplete + + lazy val global: Global = { + if (!isInitializeComplete) _initialize() + _compiler + } + + import global._ + import definitions.{ ObjectClass, termMember, dropNullaryMethod} + + lazy val runtimeMirror = ru.runtimeMirror(classLoader) + + private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol } + + def getClassIfDefined(path: String) = ( + noFatal(runtimeMirror staticClass path) + orElse noFatal(rootMirror staticClass path) + ) + def getModuleIfDefined(path: String) = ( + noFatal(runtimeMirror staticModule path) + orElse noFatal(rootMirror staticModule path) + ) + + implicit class ReplTypeOps(tp: Type) { + def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) + } + + // TODO: If we try to make naming a lazy val, we run into big time + // scalac unhappiness with what look like cycles. It has not been easy to + // reduce, but name resolution clearly takes different paths. + object naming extends { + val global: imain.global.type = imain.global + } with Naming { + // make sure we don't overwrite their unwisely named res3 etc. 
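// Illustrative sketch (not part of the patch) of the initialization pattern above: start the
// expensive setup in a Future, and gate first use behind a lazy val that awaits it.
// LazyEngine and its members are hypothetical stand-ins, not SparkIMain's real API.
import scala.concurrent.{ Await, ExecutionContext, Future }
import scala.concurrent.duration._

class LazyEngine {
  @volatile private var ready = false
  private var init: Future[Unit] = null

  def startInit(): Unit = synchronized {
    if (init == null)
      init = Future { Thread.sleep(100); ready = true }(ExecutionContext.global)
  }

  // The first touch of `engine` blocks until initialization has finished.
  lazy val engine: String = {
    if (!ready) Await.result(init, 30.seconds)
    "initialized engine"
  }
}

object LazyEngineDemo {
  def main(args: Array[String]): Unit = {
    val e = new LazyEngine
    e.startInit()
    println(e.engine)
  }
}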
+ def freshUserTermName(): TermName = { + val name = newTermName(freshUserVarName()) + if (replScope containsName name) freshUserTermName() + else name + } + def isInternalTermName(name: Name) = isInternalVarName("" + name) + } + import naming._ + + object deconstruct extends { + val global: imain.global.type = imain.global + } with StructuredTypeStrings + + lazy val memberHandlers = new { + val intp: imain.type = imain + } with SparkMemberHandlers + import memberHandlers._ + + /** Temporarily be quiet */ + def beQuietDuring[T](body: => T): T = { + val saved = printResults + printResults = false + try body + finally printResults = saved + } + def beSilentDuring[T](operation: => T): T = { + val saved = totalSilence + totalSilence = true + try operation + finally totalSilence = saved + } + + def quietRun[T](code: String) = beQuietDuring(interpret(code)) + + /** takes AnyRef because it may be binding a Throwable or an Exceptional */ + private def withLastExceptionLock[T](body: => T, alt: => T): T = { + assert(bindExceptions, "withLastExceptionLock called incorrectly.") + bindExceptions = false + + try beQuietDuring(body) + catch logAndDiscard("withLastExceptionLock", alt) + finally bindExceptions = true + } + + def executionWrapper = _executionWrapper + def setExecutionWrapper(code: String) = _executionWrapper = code + def clearExecutionWrapper() = _executionWrapper = "" + + /** interpreter settings */ + lazy val isettings = new SparkISettings(this) + + /** Instantiate a compiler. Overridable. */ + protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = { + settings.outputDirs setSingleOutput replOutput.dir + settings.exposeEmptyPackage.value = true + new Global(settings, reporter) with ReplGlobal { override def toString: String = "" } + } + + /** Parent classloader. Overridable. */ + protected def parentClassLoader: ClassLoader = + settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() ) + + /* A single class loader is used for all commands interpreted by this Interpreter. + It would also be possible to create a new class loader for each command + to interpret. The advantages of the current approach are: + + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" + + The main disadvantage is: + + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects refer to the old + definitions. + */ + def resetClassLoader() = { + repldbg("Setting new classloader: was " + _classLoader) + _classLoader = null + ensureClassLoader() + } + final def ensureClassLoader() { + if (_classLoader == null) + _classLoader = makeClassLoader() + } + def classLoader: util.AbstractFileClassLoader = { + ensureClassLoader() + _classLoader + } + + def backticked(s: String): String = ( + (s split '.').toList map { + case "_" => "_" + case s if nme.keywords(newTermName(s)) => s"`$s`" + case s => s + } mkString "." 
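// Illustrative sketch (not part of the patch) of the save/flip/restore pattern used by
// beQuietDuring and beSilentDuring above: a flag is changed only for the dynamic extent of
// the body and is always restored, even if the body throws.
object QuietDemo {
  private var printResults = true

  def beQuietDuring[T](body: => T): T = {
    val saved = printResults
    printResults = false
    try body finally printResults = saved
  }

  def main(args: Array[String]): Unit = {
    println("before: " + printResults)
    beQuietDuring { println("inside: " + printResults) }
    println("after:  " + printResults)
  }
}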
+ ) + def readRootPath(readPath: String) = getModuleIfDefined(readPath) + + abstract class PhaseDependentOps { + def shift[T](op: => T): T + + def path(name: => Name): String = shift(path(symbolOfName(name))) + def path(sym: Symbol): String = backticked(shift(sym.fullName)) + def sig(sym: Symbol): String = shift(sym.defString) + } + object typerOp extends PhaseDependentOps { + def shift[T](op: => T): T = exitingTyper(op) + } + object flatOp extends PhaseDependentOps { + def shift[T](op: => T): T = exitingFlatten(op) + } + + def originalPath(name: String): String = originalPath(name: TermName) + def originalPath(name: Name): String = typerOp path name + def originalPath(sym: Symbol): String = typerOp path sym + def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName + def translatePath(path: String) = { + val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path) + sym.toOption map flatPath + } + def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath + + private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) { + /** Overridden here to try translating a simple name to the generated + * class name if the original attempt fails. This method is used by + * getResourceAsStream as well as findClass. + */ + override protected def findAbstractFile(name: String): AbstractFile = + super.findAbstractFile(name) match { + case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull + case file => file + } + } + private def makeClassLoader(): util.AbstractFileClassLoader = + new TranslatingClassLoader(parentClassLoader match { + case null => ScalaClassLoader fromURLs compilerClasspath + case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p) + }) + + // Set the current Java "context" class loader to this interpreter's class loader + def setContextClassLoader() = classLoader.setAsContext() + + def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted) + def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted + + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + prevRequests.reverse foreach { req => + req.handlers.reverse foreach { + case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) + case _ => () + } + } + None + } + + private def updateReplScope(sym: Symbol, isDefined: Boolean) { + def log(what: String) { + val mark = if (sym.isType) "t " else "v " + val name = exitingTyper(sym.nameString) + val info = cleanTypeAfterTyper(sym) + val defn = sym defStringSeenAs info + + scopelog(f"[$mark$what%6s] $name%-25s $defn%s") + } + if (ObjectClass isSubClass sym.owner) return + // unlink previous + replScope lookupAll sym.name foreach { sym => + log("unlink") + replScope unlink sym + } + val what = if (isDefined) "define" else "import" + log(what) + replScope enter sym + } + + def recordRequest(req: Request) { + if (req == null) + return + + prevRequests += req + + // warning about serially defining companions. It'd be easy + // enough to just redefine them together but that may not always + // be what people want so I'm waiting until I can do it better. 
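// Illustrative toy version (not part of the patch) of the replScope update above: any
// previous binding with the same name is unlinked before the new one is entered, so later
// definitions shadow earlier ones rather than coexisting with them.
import scala.collection.mutable

object ScopeDemo {
  private val scope = mutable.LinkedHashMap.empty[String, String]

  def updateScope(name: String, defn: String, what: String): Unit = {
    scope.remove(name).foreach(old => println(f"[unlink] $name%-10s $old"))
    println(f"[$what%6s] $name%-10s $defn")
    scope(name) = defn
  }

  def main(args: Array[String]): Unit = {
    updateScope("x", "val x: Int = 1", "define")
    updateScope("x", "val x: String = \"one\"", "define")   // shadows the previous x
    println(scope)
  }
}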
+ exitingTyper { + req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym => + val oldSym = replScope lookup newSym.name.companionName + if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) { + replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.") + replwarn("Companions must be defined together; you may wish to use :paste mode for this.") + } + } + } + exitingTyper { + req.imports foreach (sym => updateReplScope(sym, isDefined = false)) + req.defines foreach (sym => updateReplScope(sym, isDefined = true)) + } + } + + private[nsc] def replwarn(msg: => String) { + if (!settings.nowarnings) + printMessage(msg) + } + + def compileSourcesKeepingRun(sources: SourceFile*) = { + val run = new Run() + assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") + reporter.reset() + run compileSources sources.toList + (!reporter.hasErrors, run) + } + + /** Compile an nsc SourceFile. Returns true if there are + * no compilation errors, or false otherwise. + */ + def compileSources(sources: SourceFile*): Boolean = + compileSourcesKeepingRun(sources: _*)._1 + + /** Compile a string. Returns true if there are no + * compilation errors, or false otherwise. + */ + def compileString(code: String): Boolean = + compileSources(new BatchSourceFile("