diff --git a/.rat-excludes b/.rat-excludes index eaefef1b0aa2e..fb6323daf9211 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -31,6 +31,7 @@ sorttable.js .*data .*log cloudpickle.py +heapq3.py join.py SparkExprTyper.scala SparkILoop.scala diff --git a/LICENSE b/LICENSE index e9a1153fdc5db..a7eee041129cb 100644 --- a/LICENSE +++ b/LICENSE @@ -338,6 +338,289 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +======================================================================== +For heapq (pyspark/heapq3.py): +======================================================================== + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ======================================================================== For sorttable (core/src/main/resources/org/apache/spark/ui/static/sorttable.js): diff --git a/README.md b/README.md index a1a48f5bd0819..5b09ad86849e7 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLLib for machine learning, GraphX for graph processing, -and Spark Streaming. +data processing, MLlib for machine learning, GraphX for graph processing, +and Spark Streaming for stream processing. @@ -69,7 +69,7 @@ Many of the example programs print usage help if no params are given. Testing first requires [building Spark](#building-spark). Once Spark is built, tests can be run using: - ./sbt/sbt test + ./dev/run-tests ## A Note About Hadoop Versions @@ -118,11 +118,10 @@ If your project is built with Maven, add this to your POM file's ` ## A Note About Thrift JDBC server and CLI for Spark SQL Spark SQL supports Thrift JDBC server and CLI. -See sql-programming-guide.md for more information about those features. -You can use those features by setting `-Phive-thriftserver` when building Spark as follows. - - $ sbt/sbt -Phive-thriftserver assembly +See sql-programming-guide.md for more information about using the JDBC server and CLI. +You can use those features by setting `-Phive` when building Spark as follows. + $ sbt/sbt -Phive assembly ## Configuration @@ -140,3 +139,5 @@ submitting any copyrighted material via pull request, email, or other means you agree to license the material under the project's open source license and warrant that you have the legal authority to do so. +Please see [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) +for more information. diff --git a/assembly/pom.xml b/assembly/pom.xml index 703f15925bc44..4146168fc804b 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml @@ -43,6 +43,12 @@ + + + com.google.guava + guava + compile + org.apache.spark spark-core_${scala.binary.version} @@ -113,6 +119,18 @@ shade + + + com.google + org.spark-project.guava + + com.google.common.** + + + com.google.common.base.Optional** + + + @@ -163,11 +181,6 @@ spark-hive_${scala.binary.version} ${project.version} - - - - hive-thriftserver - org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/bagel/pom.xml b/bagel/pom.xml index bd51b112e26fa..93db0d5efda5f 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index 55241d33cd3f0..ccb262a4ee02a 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -24,8 +24,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.storage.StorageLevel -import scala.language.postfixOps - class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable diff --git a/bin/beeline b/bin/beeline index 1bda4dba50605..3fcb6df34339d 100755 --- a/bin/beeline +++ b/bin/beeline @@ -24,7 +24,7 @@ set -o posix # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" CLASS="org.apache.hive.beeline.BeeLine" exec "$FWDIR/bin/spark-class" $CLASS "$@" diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 58710cd1bd548..5ad52452a5c98 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -36,7 +36,8 @@ rem Load environment variables from conf\spark-env.cmd, if it exists if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" rem Build up classpath -set CLASSPATH=%FWDIR%conf +set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%;%FWDIR%conf + if exist "%FWDIR%RELEASE" ( for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( set ASSEMBLY_JAR=%%d diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 16b794a1592e8..15c6779402994 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -23,9 +23,9 @@ SCALA_VERSION=2.10 # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -. $FWDIR/bin/load-spark-env.sh +. "$FWDIR"/bin/load-spark-env.sh # Build up classpath CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH:$FWDIR/conf" @@ -63,7 +63,7 @@ else assembly_folder="$ASSEMBLY_DIR" fi -num_jars=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l) +num_jars="$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l)" if [ "$num_jars" -eq "0" ]; then echo "Failed to find Spark assembly in $assembly_folder" echo "You need to build Spark before running this program." @@ -77,7 +77,7 @@ if [ "$num_jars" -gt "1" ]; then exit 1 fi -ASSEMBLY_JAR=$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null) +ASSEMBLY_JAR="$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)" # Verify that versions of java used to build the jars and run Spark are compatible jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) @@ -103,8 +103,8 @@ else datanucleus_dir="$FWDIR"/lib_managed/jars fi -datanucleus_jars=$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar") -datanucleus_jars=$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g) +datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar")" +datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)" if [ -n "$datanucleus_jars" ]; then hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index d425f9feaac54..6d4231b204595 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -25,14 +25,14 @@ if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 # Returns the parent of the directory this script lives in. - parent_dir="$(cd `dirname $0`/..; pwd)" + parent_dir="$(cd "`dirname "$0"`"/..; pwd)" - use_conf_dir=${SPARK_CONF_DIR:-"$parent_dir/conf"} + user_conf_dir="${SPARK_CONF_DIR:-"$parent_dir"/conf}" - if [ -f "${use_conf_dir}/spark-env.sh" ]; then + if [ -f "${user_conf_dir}/spark-env.sh" ]; then # Promote all variable declarations to environment (exported) variables set -a - . "${use_conf_dir}/spark-env.sh" + . "${user_conf_dir}/spark-env.sh" set +a fi fi diff --git a/bin/pyspark b/bin/pyspark index 01d42025c978e..5142411e36974 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,18 +18,18 @@ # # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" -source $FWDIR/bin/utils.sh +source "$FWDIR/bin/utils.sh" SCALA_VERSION=2.10 function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 exit 0 } @@ -48,7 +48,7 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi -. $FWDIR/bin/load-spark-env.sh +. "$FWDIR"/bin/load-spark-env.sh # Figure out which Python executable to use if [[ -z "$PYSPARK_PYTHON" ]]; then @@ -57,12 +57,12 @@ fi export PYSPARK_PYTHON # Add the PySpark classes to the Python path: -export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH +export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" +export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: -export OLD_PYTHONSTARTUP=$PYTHONSTARTUP -export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py +export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" +export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py" # If IPython options are specified, assume user wants to run IPython if [[ -n "$IPYTHON_OPTS" ]]; then @@ -85,6 +85,8 @@ export PYSPARK_SUBMIT_ARGS # For pyspark tests if [[ -n "$SPARK_TESTING" ]]; then + unset YARN_CONF_DIR + unset HADOOP_CONF_DIR if [[ -n "$PYSPARK_DOC_TEST" ]]; then exec "$PYSPARK_PYTHON" -m doctest $1 else @@ -97,14 +99,16 @@ fi if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 echo -e "Use ./bin/spark-submit \n" 1>&2 - primary=$1 + primary="$1" shift gatherSparkSubmitOpts "$@" - exec $FWDIR/bin/spark-submit "${SUBMISSION_OPTS[@]}" $primary "${APPLICATION_OPTS[@]}" + exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" else + # PySpark shell requires special handling downstream + export PYSPARK_SHELL=1 # Only use ipython if no command line arguments were provided [SPARK-1134] if [[ "$IPYTHON" = "1" ]]; then - exec ipython $IPYTHON_OPTS + exec ${PYSPARK_PYTHON:-ipython} $IPYTHON_OPTS else exec "$PYSPARK_PYTHON" fi diff --git a/bin/run-example b/bin/run-example index 68a35702eddd3..34dd71c71880e 100755 --- a/bin/run-example +++ b/bin/run-example @@ -19,7 +19,7 @@ SCALA_VERSION=2.10 -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" export SPARK_HOME="$FWDIR" EXAMPLES_DIR="$FWDIR"/examples @@ -35,12 +35,12 @@ else fi if [ -f "$FWDIR/RELEASE" ]; then - export SPARK_EXAMPLES_JAR=`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar` + export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`" elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then - export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar` + export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`" fi -if [[ -z $SPARK_EXAMPLES_JAR ]]; then +if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 exit 1 diff --git a/bin/spark-class b/bin/spark-class index 3f6beca5becf0..5f5f9ea74888d 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -17,6 +17,8 @@ # limitations under the License. # +# NOTE: Any changes to this file must be reflected in SparkSubmitDriverBootstrapper.scala! + cygwin=false case "`uname`" in CYGWIN*) cygwin=true;; @@ -25,12 +27,12 @@ esac SCALA_VERSION=2.10 # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" -. $FWDIR/bin/load-spark-env.sh +. "$FWDIR"/bin/load-spark-env.sh if [ -z "$1" ]; then echo "Usage: spark-class []" 1>&2 @@ -39,7 +41,7 @@ fi if [ -n "$SPARK_MEM" ]; then echo -e "Warning: SPARK_MEM is deprecated, please use a more specific config option" 1>&2 - echo -e "(e.g., spark.executor.memory or SPARK_DRIVER_MEMORY)." 1>&2 + echo -e "(e.g., spark.executor.memory or spark.driver.memory)." 1>&2 fi # Use SPARK_MEM or 512m as the default memory, to be overridden by specific options @@ -73,11 +75,17 @@ case "$1" in OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} ;; - # Spark submit uses SPARK_SUBMIT_OPTS and SPARK_JAVA_OPTS - 'org.apache.spark.deploy.SparkSubmit') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS \ - -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH" + # Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + + # SPARK_DRIVER_MEMORY + SPARK_SUBMIT_DRIVER_MEMORY. + 'org.apache.spark.deploy.SparkSubmit') + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS" OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} + if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then + OUR_JAVA_OPTS="$OUR_JAVA_OPTS -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH" + fi + if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then + OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY" + fi ;; *) @@ -97,36 +105,42 @@ else exit 1 fi fi +JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size -JAVA_OPTS="-XX:MaxPermSize=128m $OUR_JAVA_OPTS" +if [ "$JAVA_VERSION" -ge 18 ]; then + JAVA_OPTS="$OUR_JAVA_OPTS" +else + JAVA_OPTS="-XX:MaxPermSize=128m $OUR_JAVA_OPTS" +fi JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" + # Load extra JAVA_OPTS from conf/java-opts, if it exists if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" + JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" fi -export JAVA_OPTS + # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the SBT build - export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar` + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`" fi if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the Maven build # TODO: this also needs to become an assembly! - export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar` + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar`" fi # Compute classpath using external script -classpath_output=$($FWDIR/bin/compute-classpath.sh) +classpath_output=$("$FWDIR"/bin/compute-classpath.sh) if [[ "$?" != "0" ]]; then echo "$classpath_output" exit 1 else - CLASSPATH=$classpath_output + CLASSPATH="$classpath_output" fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then @@ -139,17 +153,35 @@ if [[ "$1" =~ org.apache.spark.tools.* ]]; then fi if $cygwin; then - CLASSPATH=`cygpath -wp $CLASSPATH` + CLASSPATH="`cygpath -wp "$CLASSPATH"`" if [ "$1" == "org.apache.spark.tools.JavaAPICompletenessChecker" ]; then - export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR` + export SPARK_TOOLS_JAR="`cygpath -w "$SPARK_TOOLS_JAR"`" fi fi export CLASSPATH -if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then - echo -n "Spark Command: " 1>&2 - echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" 1>&2 - echo -e "========================================\n" 1>&2 +# In Spark submit client mode, the driver is launched in the same JVM as Spark submit itself. +# Here we must parse the properties file for relevant "spark.driver.*" configs before launching +# the driver JVM itself. Instead of handling this complexity in Bash, we launch a separate JVM +# to prepare the launch environment of this driver JVM. + +if [ -n "$SPARK_SUBMIT_BOOTSTRAP_DRIVER" ]; then + # This is used only if the properties file actually contains these special configs + # Export the environment variables needed by SparkSubmitDriverBootstrapper + export RUNNER + export CLASSPATH + export JAVA_OPTS + export OUR_JAVA_MEM + export SPARK_CLASS=1 + shift # Ignore main class (org.apache.spark.deploy.SparkSubmit) and use our own + exec "$RUNNER" org.apache.spark.deploy.SparkSubmitDriverBootstrapper "$@" +else + # Note: The format of this command is closely echoed in SparkSubmitDriverBootstrapper.scala + if [ -n "$SPARK_PRINT_LAUNCH_COMMAND" ]; then + echo -n "Spark Command: " 1>&2 + echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" 1>&2 + echo -e "========================================\n" 1>&2 + fi + exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" fi -exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd old mode 100755 new mode 100644 index e420eb409e529..6c5672819172b --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -17,6 +17,8 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem +rem Any changes to this file must be reflected in SparkSubmitDriverBootstrapper.scala! + setlocal enabledelayedexpansion set SCALA_VERSION=2.10 @@ -38,7 +40,7 @@ if not "x%1"=="x" goto arg_given if not "x%SPARK_MEM%"=="x" ( echo Warning: SPARK_MEM is deprecated, please use a more specific config option - echo e.g., spark.executor.memory or SPARK_DRIVER_MEMORY. + echo e.g., spark.executor.memory or spark.driver.memory. ) rem Use SPARK_MEM or 512m as the default memory, to be overridden by specific options @@ -67,17 +69,31 @@ rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% -rem All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS. -) else if "%1"=="org.apache.spark.repl.Main" ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_REPL_OPTS% +rem Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + +rem SPARK_DRIVER_MEMORY + SPARK_SUBMIT_DRIVER_MEMORY. +rem The repl also uses SPARK_REPL_OPTS. +) else if "%1"=="org.apache.spark.deploy.SparkSubmit" ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_SUBMIT_OPTS% %SPARK_REPL_OPTS% + if not "x%SPARK_SUBMIT_LIBRARY_PATH%"=="x" ( + set OUR_JAVA_OPTS=!OUR_JAVA_OPTS! -Djava.library.path=%SPARK_SUBMIT_LIBRARY_PATH% + ) else if not "x%SPARK_LIBRARY_PATH%"=="x" ( + set OUR_JAVA_OPTS=!OUR_JAVA_OPTS! -Djava.library.path=%SPARK_LIBRARY_PATH% + ) if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% + if not "x%SPARK_SUBMIT_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_SUBMIT_DRIVER_MEMORY% ) else ( set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% ) rem Set JAVA_OPTS to be able to load native libraries and to set heap size -set JAVA_OPTS=-XX:MaxPermSize=128m %OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% +for /f "tokens=3" %%i in ('java -version 2^>^&1 ^| find "version"') do set jversion=%%i +for /f "tokens=1 delims=_" %%i in ("%jversion:~1,-1%") do set jversion=%%i +if "%jversion%" geq "1.8.0" ( + set JAVA_OPTS=%OUR_JAVA_OPTS% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% +) else ( + set JAVA_OPTS=-XX:MaxPermSize=128m %OUR_JAVA_OPTS% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% +) rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! rem Test whether the user has built Spark @@ -109,5 +125,27 @@ rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java -"%RUNNER%" -cp "%CLASSPATH%" %JAVA_OPTS% %* +rem In Spark submit client mode, the driver is launched in the same JVM as Spark submit itself. +rem Here we must parse the properties file for relevant "spark.driver.*" configs before launching +rem the driver JVM itself. Instead of handling this complexity here, we launch a separate JVM +rem to prepare the launch environment of this driver JVM. + +rem In this case, leave out the main class (org.apache.spark.deploy.SparkSubmit) and use our own. +rem Leaving out the first argument is surprisingly difficult to do in Windows. Note that this must +rem be done here because the Windows "shift" command does not work in a conditional block. +set BOOTSTRAP_ARGS= +shift +:start_parse +if "%~1" == "" goto end_parse +set BOOTSTRAP_ARGS=%BOOTSTRAP_ARGS% %~1 +shift +goto start_parse +:end_parse + +if not [%SPARK_SUBMIT_BOOTSTRAP_DRIVER%] == [] ( + set SPARK_CLASS=1 + "%RUNNER%" org.apache.spark.deploy.SparkSubmitDriverBootstrapper %BOOTSTRAP_ARGS% +) else ( + "%RUNNER%" -cp "%CLASSPATH%" %JAVA_OPTS% %* +) :exit diff --git a/bin/spark-shell b/bin/spark-shell index 8b7ccd7439551..4a0670fc6c8aa 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -22,44 +22,44 @@ cygwin=false case "`uname`" in - CYGWIN*) cygwin=true;; + CYGWIN*) cygwin=true;; esac # Enter posix mode for bash set -o posix ## Global script variables -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" function usage() { - echo "Usage: ./bin/spark-shell [options]" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 + echo "Usage: ./bin/spark-shell [options]" + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 } if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then usage fi -source $FWDIR/bin/utils.sh +source "$FWDIR"/bin/utils.sh SUBMIT_USAGE_FUNCTION=usage gatherSparkSubmitOpts "$@" function main() { - if $cygwin; then - # Workaround for issue involving JLine and Cygwin - # (see http://sourceforge.net/p/jline/bugs/40/). - # If you're using the Mintty terminal emulator in Cygwin, may need to set the - # "Backspace sends ^H" setting in "Keys" section of the Mintty options - # (see https://github.com/sbt/sbt/issues/562). - stty -icanon min 1 -echo > /dev/null 2>&1 - export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" - stty icanon echo > /dev/null 2>&1 - else - export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" - fi + if $cygwin; then + # Workaround for issue involving JLine and Cygwin + # (see http://sourceforge.net/p/jline/bugs/40/). + # If you're using the Mintty terminal emulator in Cygwin, may need to set the + # "Backspace sends ^H" setting in "Keys" section of the Mintty options + # (see https://github.com/sbt/sbt/issues/562). + stty -icanon min 1 -echo > /dev/null 2>&1 + export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + stty icanon echo > /dev/null 2>&1 + else + export SPARK_SUBMIT_OPTS + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + fi } # Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in diff --git a/bin/spark-sql b/bin/spark-sql index 564f1f419060f..ae096530cad04 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -24,9 +24,10 @@ set -o posix CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" +CLASS_NOT_FOUND_EXIT_STATUS=1 # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" function usage { echo "Usage: ./bin/spark-sql [options] [cli option]" @@ -37,58 +38,28 @@ function usage { pattern+="\|--help" pattern+="\|=======" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 echo echo "CLI options:" - $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 } -function ensure_arg_number { - arg_number=$1 - at_least=$2 - - if [[ $arg_number -lt $at_least ]]; then - usage - exit 1 - fi -} - -if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then usage exit 0 fi -CLI_ARGS=() -SUBMISSION_ARGS=() - -while (($#)); do - case $1 in - -d | --define | --database | -f | -h | --hiveconf | --hivevar | -i | -p) - ensure_arg_number $# 2 - CLI_ARGS+=("$1"); shift - CLI_ARGS+=("$1"); shift - ;; +source "$FWDIR"/bin/utils.sh +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" - -e) - ensure_arg_number $# 2 - CLI_ARGS+=("$1"); shift - CLI_ARGS+=("$1"); shift - ;; +"$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" +exit_status=$? - -s | --silent) - CLI_ARGS+=("$1"); shift - ;; - - -v | --verbose) - # Both SparkSubmit and SparkSQLCLIDriver recognizes -v | --verbose - CLI_ARGS+=("$1") - SUBMISSION_ARGS+=("$1"); shift - ;; - - *) - SUBMISSION_ARGS+=("$1"); shift - ;; - esac -done +if [[ exit_status -eq CLASS_NOT_FOUND_EXIT_STATUS ]]; then + echo + echo "Failed to load Spark SQL CLI main class $CLASS." + echo "You need to build Spark with -Phive." +fi -exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${CLI_ARGS[@]}" +exit $exit_status diff --git a/bin/spark-submit b/bin/spark-submit index 9e7cecedd0325..c557311b4b20e 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -17,14 +17,18 @@ # limitations under the License. # -export SPARK_HOME="$(cd `dirname $0`/..; pwd)" +# NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! + +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" ORIG_ARGS=("$@") while (($#)); do if [ "$1" = "--deploy-mode" ]; then - DEPLOY_MODE=$2 + SPARK_SUBMIT_DEPLOY_MODE=$2 + elif [ "$1" = "--properties-file" ]; then + SPARK_SUBMIT_PROPERTIES_FILE=$2 elif [ "$1" = "--driver-memory" ]; then - DRIVER_MEMORY=$2 + export SPARK_SUBMIT_DRIVER_MEMORY=$2 elif [ "$1" = "--driver-library-path" ]; then export SPARK_SUBMIT_LIBRARY_PATH=$2 elif [ "$1" = "--driver-class-path" ]; then @@ -35,11 +39,25 @@ while (($#)); do shift done -DEPLOY_MODE=${DEPLOY_MODE:-"client"} +DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf" +export SPARK_SUBMIT_DEPLOY_MODE=${SPARK_SUBMIT_DEPLOY_MODE:-"client"} +export SPARK_SUBMIT_PROPERTIES_FILE=${SPARK_SUBMIT_PROPERTIES_FILE:-"$DEFAULT_PROPERTIES_FILE"} + +# For client mode, the driver will be launched in the same JVM that launches +# SparkSubmit, so we may need to read the properties file for any extra class +# paths, library paths, java options and memory early on. Otherwise, it will +# be too late by the time the driver JVM has started. -if [ -n "$DRIVER_MEMORY" ] && [ $DEPLOY_MODE == "client" ]; then - export SPARK_DRIVER_MEMORY=$DRIVER_MEMORY +if [[ "$SPARK_SUBMIT_DEPLOY_MODE" == "client" && -f "$SPARK_SUBMIT_PROPERTIES_FILE" ]]; then + # Parse the properties file only if the special configs exist + contains_special_configs=$( + grep -e "spark.driver.extra*\|spark.driver.memory" "$SPARK_SUBMIT_PROPERTIES_FILE" | \ + grep -v "^[[:space:]]*#" + ) + if [ -n "$contains_special_configs" ]; then + export SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 + fi fi -exec $SPARK_HOME/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" +exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index 6eb702ed8c561..cf6046d1547ad 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -17,23 +17,28 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem +rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! + set SPARK_HOME=%~dp0.. set ORIG_ARGS=%* -rem Clear the values of all variables used -set DEPLOY_MODE= -set DRIVER_MEMORY= +rem Reset the values of all variables used +set SPARK_SUBMIT_DEPLOY_MODE=client +set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf +set SPARK_SUBMIT_DRIVER_MEMORY= set SPARK_SUBMIT_LIBRARY_PATH= set SPARK_SUBMIT_CLASSPATH= set SPARK_SUBMIT_OPTS= -set SPARK_DRIVER_MEMORY= +set SPARK_SUBMIT_BOOTSTRAP_DRIVER= :loop if [%1] == [] goto continue if [%1] == [--deploy-mode] ( - set DEPLOY_MODE=%2 + set SPARK_SUBMIT_DEPLOY_MODE=%2 + ) else if [%1] == [--properties-file] ( + set SPARK_SUBMIT_PROPERTIES_FILE=%2 ) else if [%1] == [--driver-memory] ( - set DRIVER_MEMORY=%2 + set SPARK_SUBMIT_DRIVER_MEMORY=%2 ) else if [%1] == [--driver-library-path] ( set SPARK_SUBMIT_LIBRARY_PATH=%2 ) else if [%1] == [--driver-class-path] ( @@ -45,12 +50,19 @@ if [%1] == [] goto continue goto loop :continue -if [%DEPLOY_MODE%] == [] ( - set DEPLOY_MODE=client -) +rem For client mode, the driver will be launched in the same JVM that launches +rem SparkSubmit, so we may need to read the properties file for any extra class +rem paths, library paths, java options and memory early on. Otherwise, it will +rem be too late by the time the driver JVM has started. -if not [%DRIVER_MEMORY%] == [] if [%DEPLOY_MODE%] == [client] ( - set SPARK_DRIVER_MEMORY=%DRIVER_MEMORY% +if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( + if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( + rem Parse the properties file only if the special configs exist + for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ + %SPARK_SUBMIT_PROPERTIES_FILE%') do ( + set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 + ) + ) ) cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% diff --git a/bin/utils.sh b/bin/utils.sh old mode 100644 new mode 100755 diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index 2779342769c14..a48dcc70e1363 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -2,7 +2,9 @@ # This is useful for setting default environmental settings. # Example: -# spark.master spark://master:7077 -# spark.eventLog.enabled true -# spark.eventLog.dir hdfs://namenode:8021/directory -# spark.serializer org.apache.spark.serializer.KryoSerializer +# spark.master spark://master:7077 +# spark.eventLog.enabled true +# spark.eventLog.dir hdfs://namenode:8021/directory +# spark.serializer org.apache.spark.serializer.KryoSerializer +# spark.driver.memory 5g +# spark.executor.extraJavaOptions -XX:+PrintGCDetails -Dkey=value -Dnumbers="one two three" diff --git a/core/pom.xml b/core/pom.xml index 6d8be37037729..b2b788a4bc13b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml @@ -68,9 +68,15 @@ org.eclipse.jetty jetty-server + com.google.guava guava + compile org.apache.commons @@ -300,28 +306,51 @@ - org.codehaus.mojo - exec-maven-plugin - 1.2.1 + org.apache.maven.plugins + maven-antrun-plugin generate-resources - exec + run - unzip - ../python - - -o - lib/py4j*.zip - -d - build - + + + + + org.apache.maven.plugins + maven-shade-plugin + + + package + + shade + + + false + + + com.google.guava:guava + + + + + + com.google.guava:guava + + com/google/common/base/Optional* + + + + + + + diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index 7abb9011ccf36..dbacbf19beee5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -81,15 +81,15 @@ sorttable = { if (!headrow[i].className.match(/\bsorttable_nosort\b/)) { // skip this col mtch = headrow[i].className.match(/\bsorttable_([a-z0-9]+)\b/); if (mtch) { override = mtch[1]; } - if (mtch && typeof sorttable["sort_"+override] == 'function') { - headrow[i].sorttable_sortfunction = sorttable["sort_"+override]; - } else { - headrow[i].sorttable_sortfunction = sorttable.guessType(table,i); - } - // make it clickable to sort - headrow[i].sorttable_columnindex = i; - headrow[i].sorttable_tbody = table.tBodies[0]; - dean_addEvent(headrow[i],"click", function(e) { + if (mtch && typeof sorttable["sort_"+override] == 'function') { + headrow[i].sorttable_sortfunction = sorttable["sort_"+override]; + } else { + headrow[i].sorttable_sortfunction = sorttable.guessType(table,i); + } + // make it clickable to sort + headrow[i].sorttable_columnindex = i; + headrow[i].sorttable_tbody = table.tBodies[0]; + dean_addEvent(headrow[i],"click", function(e) { if (this.className.search(/\bsorttable_sorted\b/) != -1) { // if we're already sorted by this column, just @@ -109,7 +109,7 @@ sorttable = { // re-reverse the table, which is quicker sorttable.reverse(this.sorttable_tbody); this.className = this.className.replace('sorttable_sorted_reverse', - 'sorttable_sorted'); + 'sorttable_sorted'); this.removeChild(document.getElementById('sorttable_sortrevind')); sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; @@ -117,7 +117,7 @@ sorttable = { this.appendChild(sortfwdind); return; } - + // remove sorttable_sorted classes theadrow = this.parentNode; forEach(theadrow.childNodes, function(cell) { @@ -130,36 +130,36 @@ sorttable = { if (sortfwdind) { sortfwdind.parentNode.removeChild(sortfwdind); } sortrevind = document.getElementById('sorttable_sortrevind'); if (sortrevind) { sortrevind.parentNode.removeChild(sortrevind); } - + this.className += ' sorttable_sorted'; sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; this.appendChild(sortfwdind); - // build an array to sort. This is a Schwartzian transform thing, - // i.e., we "decorate" each row with the actual sort key, - // sort based on the sort keys, and then put the rows back in order - // which is a lot faster because you only do getInnerText once per row - row_array = []; - col = this.sorttable_columnindex; - rows = this.sorttable_tbody.rows; - for (var j=0; j 0 ) { - var q = list[i]; list[i] = list[i+1]; list[i+1] = q; - swap = true; - } - } // for - t--; + swap = false; + for(var i = b; i < t; ++i) { + if ( comp_func(list[i], list[i+1]) > 0 ) { + var q = list[i]; list[i] = list[i+1]; list[i+1] = q; + swap = true; + } + } // for + t--; - if (!swap) break; - - for(var i = t; i > b; --i) { - if ( comp_func(list[i], list[i-1]) < 0 ) { - var q = list[i]; list[i] = list[i-1]; list[i-1] = q; - swap = true; - } - } // for - b++; + if (!swap) break; + for(var i = t; i > b; --i) { + if ( comp_func(list[i], list[i-1]) < 0 ) { + var q = list[i]; list[i] = list[i-1]; list[i-1] = q; + swap = true; + } + } // for + b++; } // while(swap) } } @@ -358,11 +357,11 @@ if (document.addEventListener) { /* for Safari */ if (/WebKit/i.test(navigator.userAgent)) { // sniff - var _timer = setInterval(function() { - if (/loaded|complete/.test(document.readyState)) { - sorttable.init(); // call the onload handler - } - }, 10); + var _timer = setInterval(function() { + if (/loaded|complete/.test(document.readyState)) { + sorttable.init(); // call the onload handler + } + }, 10); } /* for other browsers */ @@ -374,66 +373,66 @@ window.onload = sorttable.init; // http://dean.edwards.name/weblog/2005/10/add-event/ function dean_addEvent(element, type, handler) { - if (element.addEventListener) { - element.addEventListener(type, handler, false); - } else { - // assign each event handler a unique ID - if (!handler.$$guid) handler.$$guid = dean_addEvent.guid++; - // create a hash table of event types for the element - if (!element.events) element.events = {}; - // create a hash table of event handlers for each element/event pair - var handlers = element.events[type]; - if (!handlers) { - handlers = element.events[type] = {}; - // store the existing event handler (if there is one) - if (element["on" + type]) { - handlers[0] = element["on" + type]; - } - } - // store the event handler in the hash table - handlers[handler.$$guid] = handler; - // assign a global event handler to do all the work - element["on" + type] = handleEvent; - } + if (element.addEventListener) { + element.addEventListener(type, handler, false); + } else { + // assign each event handler a unique ID + if (!handler.$$guid) handler.$$guid = dean_addEvent.guid++; + // create a hash table of event types for the element + if (!element.events) element.events = {}; + // create a hash table of event handlers for each element/event pair + var handlers = element.events[type]; + if (!handlers) { + handlers = element.events[type] = {}; + // store the existing event handler (if there is one) + if (element["on" + type]) { + handlers[0] = element["on" + type]; + } + } + // store the event handler in the hash table + handlers[handler.$$guid] = handler; + // assign a global event handler to do all the work + element["on" + type] = handleEvent; + } }; // a counter used to create unique IDs dean_addEvent.guid = 1; function removeEvent(element, type, handler) { - if (element.removeEventListener) { - element.removeEventListener(type, handler, false); - } else { - // delete the event handler from the hash table - if (element.events && element.events[type]) { - delete element.events[type][handler.$$guid]; - } - } + if (element.removeEventListener) { + element.removeEventListener(type, handler, false); + } else { + // delete the event handler from the hash table + if (element.events && element.events[type]) { + delete element.events[type][handler.$$guid]; + } + } }; function handleEvent(event) { - var returnValue = true; - // grab the event object (IE uses a global event object) - event = event || fixEvent(((this.ownerDocument || this.document || this).parentWindow || window).event); - // get a reference to the hash table of event handlers - var handlers = this.events[event.type]; - // execute each event handler - for (var i in handlers) { - this.$$handleEvent = handlers[i]; - if (this.$$handleEvent(event) === false) { - returnValue = false; - } - } - return returnValue; + var returnValue = true; + // grab the event object (IE uses a global event object) + event = event || fixEvent(((this.ownerDocument || this.document || this).parentWindow || window).event); + // get a reference to the hash table of event handlers + var handlers = this.events[event.type]; + // execute each event handler + for (var i in handlers) { + this.$$handleEvent = handlers[i]; + if (this.$$handleEvent(event) === false) { + returnValue = false; + } + } + return returnValue; }; function fixEvent(event) { - // add W3C standard event methods - event.preventDefault = fixEvent.preventDefault; - event.stopPropagation = fixEvent.stopPropagation; - return event; + // add W3C standard event methods + event.preventDefault = fixEvent.preventDefault; + event.stopPropagation = fixEvent.stopPropagation; + return event; }; fixEvent.preventDefault = function() { - this.returnValue = false; + this.returnValue = false; }; fixEvent.stopPropagation = function() { this.cancelBubble = true; @@ -441,55 +440,55 @@ fixEvent.stopPropagation = function() { // Dean's forEach: http://dean.edwards.name/base/forEach.js /* - forEach, version 1.0 - Copyright 2006, Dean Edwards - License: http://www.opensource.org/licenses/mit-license.php +forEach, version 1.0 +Copyright 2006, Dean Edwards +License: http://www.opensource.org/licenses/mit-license.php */ // array-like enumeration if (!Array.forEach) { // mozilla already supports this - Array.forEach = function(array, block, context) { - for (var i = 0; i < array.length; i++) { - block.call(context, array[i], i, array); - } - }; + Array.forEach = function(array, block, context) { + for (var i = 0; i < array.length; i++) { + block.call(context, array[i], i, array); + } + }; } // generic enumeration Function.prototype.forEach = function(object, block, context) { - for (var key in object) { - if (typeof this.prototype[key] == "undefined") { - block.call(context, object[key], key, object); - } - } + for (var key in object) { + if (typeof this.prototype[key] == "undefined") { + block.call(context, object[key], key, object); + } + } }; // character enumeration String.forEach = function(string, block, context) { - Array.forEach(string.split(""), function(chr, index) { - block.call(context, chr, index, string); - }); + Array.forEach(string.split(""), function(chr, index) { + block.call(context, chr, index, string); + }); }; // globally resolve forEach enumeration var forEach = function(object, block, context) { - if (object) { - var resolve = Object; // default - if (object instanceof Function) { - // functions have a "length" property - resolve = Function; - } else if (object.forEach instanceof Function) { - // the object implements a custom forEach method so use that - object.forEach(block, context); - return; - } else if (typeof object == "string") { - // the object is a string - resolve = String; - } else if (typeof object.length == "number") { - // the object is array-like - resolve = Array; - } - resolve.forEach(object, block, context); - } + if (object) { + var resolve = Object; // default + if (object instanceof Function) { + // functions have a "length" property + resolve = Function; + } else if (object.forEach instanceof Function) { + // the object implements a custom forEach method so use that + object.forEach(block, context); + return; + } else if (typeof object == "string") { + // the object is a string + resolve = String; + } else if (typeof object.length == "number") { + // the object is array-like + resolve = Array; + } + resolve.forEach(object, block, context); + } }; diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 5ddda4d6953fa..f8584b90cabe6 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -68,7 +68,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // Otherwise, cache the values and keep track of any updates in block statuses val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks) - context.taskMetrics.updatedBlocks = Some(updatedBlocks) + val metrics = context.taskMetrics + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ updatedBlocks.toSeq) new InterruptibleIterator(context, cachedValues) } finally { diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 3848734d6f639..ede1e23f4fcc5 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -65,7 +65,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} /** - * Whether the cleaning thread will block on cleanup tasks. + * Whether the cleaning thread will block on cleanup tasks (other than shuffle, which + * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). * * Due to SPARK-3015, this is set to true by default. This is intended to be only a temporary * workaround for the issue, which is ultimately caused by the way the BlockManager actors @@ -76,6 +77,19 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val blockOnCleanupTasks = sc.conf.getBoolean( "spark.cleaner.referenceTracking.blocking", true) + /** + * Whether the cleaning thread will block on shuffle cleanup tasks. + * + * When context cleaner is configured to block on every delete request, it can throw timeout + * exceptions on cleanup of shuffle blocks, as reported in SPARK-3139. To avoid that, this + * parameter by default disables blocking on shuffle cleanups. Note that this does not affect + * the cleanup of RDDs and broadcasts. This is intended to be a temporary workaround, + * until the real Akka issue (referred to in the comment above `blockOnCleanupTasks`) is + * resolved. + */ + private val blockOnShuffleCleanupTasks = sc.conf.getBoolean( + "spark.cleaner.referenceTracking.blocking.shuffle", false) + @volatile private var stopped = false /** Attach a listener object to get information of when objects are cleaned. */ @@ -128,7 +142,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = blockOnCleanupTasks) case CleanShuffle(shuffleId) => - doCleanupShuffle(shuffleId, blocking = blockOnCleanupTasks) + doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1e4dec86a0530..75ea535f2f57b 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -149,6 +149,9 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: case JobFailed(e: Exception) => scala.util.Failure(e) } } + + /** Get the corresponding job id for this action. */ + def jobId = jobWaiter.jobId } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 25c2c9fc6af7c..3832a780ec4bc 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -162,7 +162,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { // always add the current user and SPARK_USER to the viewAcls private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), - Option(System.getenv("SPARK_USER")).getOrElse("")) + Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty) setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) @@ -294,7 +294,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - if (aclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + !aclsEnabled || user == null || viewAcls.contains(user) } /** @@ -309,7 +309,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",")) - if (aclsEnabled() && (user != null) && (!modifyAcls.contains(user))) false else true + !aclsEnabled || user == null || modifyAcls.contains(user) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e132955f0f850..218b353dd9d49 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -220,30 +220,17 @@ class SparkContext(config: SparkConf) extends Logging { new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) // Initialize the Spark UI, registering all associated listeners - private[spark] val ui = new SparkUI(this) - ui.bind() + private[spark] val ui: Option[SparkUI] = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(new SparkUI(this)) + } else { + // For tests, do not enable the UI + None + } + ui.foreach(_.bind()) /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ - val hadoopConfiguration: Configuration = { - val hadoopConf = SparkHadoopUtil.get.newConfiguration() - // Explicitly check for S3 environment variables - if (System.getenv("AWS_ACCESS_KEY_ID") != null && - System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - } - // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - conf.getAll.foreach { case (key, value) => - if (key.startsWith("spark.hadoop.")) { - hadoopConf.set(key.substring("spark.hadoop.".length), value) - } - } - val bufferSize = conf.get("spark.buffer.size", "65536") - hadoopConf.set("io.file.buffer.size", bufferSize) - hadoopConf - } + val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) // Optionally log Spark events private[spark] val eventLogger: Option[EventLoggingListener] = { @@ -815,7 +802,7 @@ class SparkContext(config: SparkConf) extends Logging { * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, - * use `SparkFiles.get(path)` to find its download location. + * use `SparkFiles.get(fileName)` to find its download location. */ def addFile(path: String) { val uri = new URI(path) @@ -827,7 +814,8 @@ class SparkContext(config: SparkConf) extends Logging { addedFiles(key) = System.currentTimeMillis // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, + hadoopConfiguration) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) postEnvironmentUpdate() @@ -843,7 +831,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** The version of Spark on which this application is running. */ - def version = SparkContext.SPARK_VERSION + def version = SPARK_VERSION /** * Return a map from the slave to the max memory available for caching and the remaining @@ -1008,7 +996,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Shut down the SparkContext. */ def stop() { postApplicationEnd() - ui.stop() + ui.foreach(_.stop()) // Do this only if not stopped already - best case effort. // prevent NPE if stopped more than once. val dagSchedulerCopy = dagScheduler @@ -1279,7 +1267,10 @@ class SparkContext(config: SparkConf) extends Logging { /** Post the application start event */ private def postApplicationStart() { - listenerBus.post(SparkListenerApplicationStart(appName, startTime, sparkUser)) + // Note: this code assumes that the task scheduler has been initialized and has contacted + // the cluster manager to get an application ID (in case the cluster manager provides one). + listenerBus.post(SparkListenerApplicationStart(appName, taskScheduler.applicationId(), + startTime, sparkUser)) } /** Post the application end event */ @@ -1312,8 +1303,6 @@ class SparkContext(config: SparkConf) extends Logging { */ object SparkContext extends Logging { - private[spark] val SPARK_VERSION = "1.0.0" - private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" @@ -1637,4 +1626,3 @@ private[spark] class WritableConverter[T]( val writableClass: ClassTag[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable - diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index fc36e37c53f5e..dd95e406f2a8e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -31,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.network.ConnectionManager +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} @@ -59,8 +60,8 @@ class SparkEnv ( val mapOutputTracker: MapOutputTracker, val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, + val blockTransferService: BlockTransferService, val blockManager: BlockManager, - val connectionManager: ConnectionManager, val securityManager: SecurityManager, val httpFileServer: HttpFileServer, val sparkFilesDir: String, @@ -88,6 +89,8 @@ class SparkEnv ( // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. } private[spark] @@ -111,6 +114,9 @@ object SparkEnv extends Logging { private val env = new ThreadLocal[SparkEnv] @volatile private var lastSetSparkEnv : SparkEnv = _ + private[spark] val driverActorSystemName = "sparkDriver" + private[spark] val executorActorSystemName = "sparkExecutor" + def set(e: SparkEnv) { lastSetSparkEnv = e env.set(e) @@ -146,9 +152,9 @@ object SparkEnv extends Logging { } val securityManager = new SecurityManager(conf) - - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, - securityManager = securityManager) + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + actorSystemName, hostname, port, conf, securityManager) // Figure out which port Akka actually bound to in case the original port is 0 or occupied. // This is so that we tell the executors the correct port to connect to. @@ -214,20 +220,20 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") - val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val blockTransferService = new NioBlockTransferService(conf, securityManager) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) + new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker, shuffleManager) - - val connectionManager = blockManager.connectionManager + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) @@ -275,8 +281,8 @@ object SparkEnv extends Logging { mapOutputTracker, shuffleManager, broadcastManager, + blockTransferService, blockManager, - connectionManager, securityManager, httpFileServer, sparkFilesDir, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index feeb6c02caa78..880f61c49726e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -758,6 +758,32 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) rdd.saveAsHadoopDataset(conf) } + /** + * Repartition the RDD according to the given partitioner and, within each resulting partition, + * sort records by their keys. + * + * This is more efficient than calling `repartition` and then sorting within each partition + * because it can push the sorting down into the shuffle machinery. + */ + def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] + repartitionAndSortWithinPartitions(partitioner, comp) + } + + /** + * Repartition the RDD according to the given partitioner and, within each resulting partition, + * sort records by their keys. + * + * This is more efficient than calling `repartition` and then sorting within each partition + * because it can push the sorting down into the shuffle machinery. + */ + def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K]) + : JavaPairRDD[K, V] = { + implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering. + fromRDD( + new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartitions(partitioner)) + } + /** * Sort the RDD by key, so that each partition contains a sorted range of the elements in * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index f917cfd1419ec..545bc0e9e99ed 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag @@ -574,4 +574,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def name(): String = rdd.name + /** + * :: Experimental :: + * The asynchronous version of the foreach action. + * + * @param f the function to apply to all the elements of the RDD + * @return a FutureAction for the action + */ + @Experimental + def foreachAsync(f: VoidFunction[T]): FutureAction[Unit] = { + import org.apache.spark.SparkContext._ + rdd.foreachAsync(x => f.call(x)) + } + } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e0a4815940db3..8e178bc8480f7 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -545,7 +545,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, - * use `SparkFiles.get(path)` to find its download location. + * use `SparkFiles.get(fileName)` to find its download location. */ def addFile(path: String) { sc.addFile(path) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 747023812f754..ae8010300a500 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -749,6 +749,23 @@ private[spark] object PythonRDD extends Logging { } } } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]] + } else { + Seq(obj) + } + } + }.toJavaRDD() + } } private diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 52c70712eea3d..be5ebfa9219d3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -40,28 +40,3 @@ private[spark] object PythonUtils { paths.filter(_ != "").mkString(File.pathSeparator) } } - - -/** - * A utility class to redirect the child process's stdout or stderr. - */ -private[spark] class RedirectThread( - in: InputStream, - out: OutputStream, - name: String) - extends Thread(name) { - - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - out.write(buf, 0, len) - out.flush() - len = in.read(buf) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index bf716a8ab025b..4c4796f6c59ba 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,6 @@ package org.apache.spark.api.python -import java.lang.Runtime import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} @@ -25,7 +24,7 @@ import scala.collection.mutable import scala.collection.JavaConversions._ import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 6173fd3a69fc7..42d58682a1e23 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -28,6 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.io.ByteArrayChunkOutputStream /** * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. @@ -201,29 +202,12 @@ private object TorrentBroadcast extends Logging { } def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = { - // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks - // so we don't need to do the extra memory copy. - val bos = new ByteArrayOutputStream() + val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE) val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() - val byteArray = bos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt - val blocks = new Array[ByteBuffer](numBlocks) - - var blockId = 0 - for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { - val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - val tempByteArray = new Array[Byte](thisBlockSize) - bais.read(tempByteArray, 0, thisBlockSize) - - blocks(blockId) = ByteBuffer.wrap(tempByteArray) - blockId += 1 - } - bais.close() - blocks + bos.toArrays.map(ByteBuffer.wrap) } def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = { diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 0d6751f3fa6d2..b66c3ba4d5fb0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -22,8 +22,8 @@ import java.net.URI import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ -import org.apache.spark.api.python.{PythonUtils, RedirectThread} -import org.apache.spark.util.Utils +import org.apache.spark.api.python.PythonUtils +import org.apache.spark.util.{RedirectThread, Utils} /** * A main class used by spark-submit to launch Python applications. It executes python as a diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 148115d3ed351..fe0ad9ebbca12 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -24,15 +24,18 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation -import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} +import org.apache.spark.annotation.DeveloperApi import scala.collection.JavaConversions._ /** + * :: DeveloperApi :: * Contains util methods to interact with Hadoop from Spark. */ +@DeveloperApi class SparkHadoopUtil extends Logging { - val conf: Configuration = newConfiguration() + val conf: Configuration = newConfiguration(new SparkConf()) UserGroupInformation.setConfiguration(conf) /** @@ -64,11 +67,39 @@ class SparkHadoopUtil extends Logging { } } + @Deprecated + def newConfiguration(): Configuration = newConfiguration(null) + /** * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop * subsystems. */ - def newConfiguration(): Configuration = new Configuration() + def newConfiguration(conf: SparkConf): Configuration = { + val hadoopConf = new Configuration() + + // Note: this null check is around more than just access to the "conf" object to maintain + // the behavior of the old implementation of this code, for backwards compatibility. + if (conf != null) { + // Explicitly check for S3 environment variables + if (System.getenv("AWS_ACCESS_KEY_ID") != null && + System.getenv("AWS_SECRET_ACCESS_KEY") != null) { + hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + } + // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" + conf.getAll.foreach { case (key, value) => + if (key.startsWith("spark.hadoop.")) { + hadoopConf.set(key.substring("spark.hadoop.".length), value) + } + } + val bufferSize = conf.get("spark.buffer.size", "65536") + hadoopConf.set("io.file.buffer.size", bufferSize) + } + + hadoopConf + } /** * Add any user credentials to the job conf which are necessary for running on a secure Hadoop @@ -86,7 +117,7 @@ class SparkHadoopUtil extends Logging { def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null } - def loginUserFromKeytab(principalName: String, keytabFilename: String) { + def loginUserFromKeytab(principalName: String, keytabFilename: String) { UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 318509a67a36f..0fdb5ae3c2e40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -54,6 +54,8 @@ object SparkSubmit { private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val CLASS_NOT_FOUND_EXIT_STATUS = 1 + // Exposed for testing private[spark] var exitFn: () => Unit = () => System.exit(-1) private[spark] var printStream: PrintStream = System.err @@ -171,6 +173,14 @@ object SparkSubmit { OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, + sysProp = "spark.driver.memory"), + OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.driver.extraClassPath"), + OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.driver.extraJavaOptions"), + OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.driver.extraLibraryPath"), // Standalone cluster only OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), @@ -195,12 +205,6 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options - OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, CLUSTER, - sysProp = "spark.driver.extraClassPath"), - OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, CLUSTER, - sysProp = "spark.driver.extraJavaOptions"), - OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, - sysProp = "spark.driver.extraLibraryPath"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, @@ -308,8 +312,18 @@ object SparkSubmit { System.setProperty(key, value) } - val mainClass = Class.forName(childMainClass, true, loader) + var mainClass: Class[_] = null + + try { + mainClass = Class.forName(childMainClass, true, loader) + } catch { + case e: ClassNotFoundException => + e.printStackTrace(printStream) + System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + } + val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) + try { mainMethod.invoke(null, childArgs.toArray) } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala new file mode 100644 index 0000000000000..38b5d8e1739d0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.File + +import scala.collection.JavaConversions._ + +import org.apache.spark.util.{RedirectThread, Utils} + +/** + * Launch an application through Spark submit in client mode with the appropriate classpath, + * library paths, java options and memory. These properties of the JVM must be set before the + * driver JVM is launched. The sole purpose of this class is to avoid handling the complexity + * of parsing the properties file for such relevant configs in Bash. + * + * Usage: org.apache.spark.deploy.SparkSubmitDriverBootstrapper + */ +private[spark] object SparkSubmitDriverBootstrapper { + + // Note: This class depends on the behavior of `bin/spark-class` and `bin/spark-submit`. + // Any changes made there must be reflected in this file. + + def main(args: Array[String]): Unit = { + + // This should be called only from `bin/spark-class` + if (!sys.env.contains("SPARK_CLASS")) { + System.err.println("SparkSubmitDriverBootstrapper must be called from `bin/spark-class`!") + System.exit(1) + } + + val submitArgs = args + val runner = sys.env("RUNNER") + val classpath = sys.env("CLASSPATH") + val javaOpts = sys.env("JAVA_OPTS") + val defaultDriverMemory = sys.env("OUR_JAVA_MEM") + + // Spark submit specific environment variables + val deployMode = sys.env("SPARK_SUBMIT_DEPLOY_MODE") + val propertiesFile = sys.env("SPARK_SUBMIT_PROPERTIES_FILE") + val bootstrapDriver = sys.env("SPARK_SUBMIT_BOOTSTRAP_DRIVER") + val submitDriverMemory = sys.env.get("SPARK_SUBMIT_DRIVER_MEMORY") + val submitLibraryPath = sys.env.get("SPARK_SUBMIT_LIBRARY_PATH") + val submitClasspath = sys.env.get("SPARK_SUBMIT_CLASSPATH") + val submitJavaOpts = sys.env.get("SPARK_SUBMIT_OPTS") + + assume(runner != null, "RUNNER must be set") + assume(classpath != null, "CLASSPATH must be set") + assume(javaOpts != null, "JAVA_OPTS must be set") + assume(defaultDriverMemory != null, "OUR_JAVA_MEM must be set") + assume(deployMode == "client", "SPARK_SUBMIT_DEPLOY_MODE must be \"client\"!") + assume(propertiesFile != null, "SPARK_SUBMIT_PROPERTIES_FILE must be set") + assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") + + // Parse the properties file for the equivalent spark.driver.* configs + val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap + val confDriverMemory = properties.get("spark.driver.memory") + val confLibraryPath = properties.get("spark.driver.extraLibraryPath") + val confClasspath = properties.get("spark.driver.extraClassPath") + val confJavaOpts = properties.get("spark.driver.extraJavaOptions") + + // Favor Spark submit arguments over the equivalent configs in the properties file. + // Note that we do not actually use the Spark submit values for library path, classpath, + // and Java opts here, because we have already captured them in Bash. + + val newDriverMemory = submitDriverMemory + .orElse(confDriverMemory) + .getOrElse(defaultDriverMemory) + + val newLibraryPath = + if (submitLibraryPath.isDefined) { + // SPARK_SUBMIT_LIBRARY_PATH is already captured in JAVA_OPTS + "" + } else { + confLibraryPath.map("-Djava.library.path=" + _).getOrElse("") + } + + val newClasspath = + if (submitClasspath.isDefined) { + // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH + classpath + } else { + classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("") + } + + val newJavaOpts = + if (submitJavaOpts.isDefined) { + // SPARK_SUBMIT_OPTS is already captured in JAVA_OPTS + javaOpts + } else { + javaOpts + confJavaOpts.map(" " + _).getOrElse("") + } + + val filteredJavaOpts = Utils.splitCommandString(newJavaOpts) + .filterNot(_.startsWith("-Xms")) + .filterNot(_.startsWith("-Xmx")) + + // Build up command + val command: Seq[String] = + Seq(runner) ++ + Seq("-cp", newClasspath) ++ + Seq(newLibraryPath) ++ + filteredJavaOpts ++ + Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++ + Seq("org.apache.spark.deploy.SparkSubmit") ++ + submitArgs + + // Print the launch command. This follows closely the format used in `bin/spark-class`. + if (sys.env.contains("SPARK_PRINT_LAUNCH_COMMAND")) { + System.err.print("Spark Command: ") + System.err.println(command.mkString(" ")) + System.err.println("========================================\n") + } + + // Start the driver JVM + val filteredCommand = command.filter(_.nonEmpty) + val builder = new ProcessBuilder(filteredCommand) + val process = builder.start() + + // Redirect stdout and stderr from the child JVM + val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout") + val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr") + stdoutThread.start() + stderrThread.start() + + // Redirect stdin to child JVM only if we're not running Windows. This is because the + // subprocess there already reads directly from our stdin, so we should avoid spawning a + // thread that contends with the subprocess in reading from System.in. + val isWindows = Utils.isWindows + val isPySparkShell = sys.env.contains("PYSPARK_SHELL") + if (!isWindows) { + val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin") + stdinThread.start() + // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM + // should terminate on broken pipe, which signals that the parent process has exited. In + // Windows, the termination logic for the PySpark shell is handled in java_gateway.py + if (isPySparkShell) { + stdinThread.join() + process.destroy() + } + } + process.waitFor() + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index a0e8bd403a41d..fbe39b27649f6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -34,15 +34,15 @@ private[spark] abstract class ApplicationHistoryProvider { * * @return List of all know applications. */ - def getListing(): Seq[ApplicationHistoryInfo] + def getListing(): Iterable[ApplicationHistoryInfo] /** * Returns the Spark UI for a specific application. * * @param appId The application ID. - * @return The application's UI, or null if application is not found. + * @return The application's UI, or None if application is not found. */ - def getAppUI(appId: String): SparkUI + def getAppUI(appId: String): Option[SparkUI] /** * Called when the server is shutting down. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index cc06540ee0647..481f6c93c6a8d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.Utils @@ -31,6 +32,8 @@ import org.apache.spark.util.Utils private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider with Logging { + private val NOT_STARTED = "" + // Interval between each check for event log updates private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", conf.getInt("spark.history.updateInterval", 10)) * 1000 @@ -40,13 +43,21 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis .map { d => Utils.resolveURI(d) } .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") } - private val fs = Utils.getHadoopFileSystem(resolvedLogDir) + private val fs = Utils.getHadoopFileSystem(resolvedLogDir, + SparkHadoopUtil.get.newConfiguration(conf)) // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L - // List of applications, in order from newest to oldest. - @volatile private var appList: Seq[ApplicationHistoryInfo] = Nil + // The modification time of the newest log detected during the last scan. This is used + // to ignore logs that are older during subsequent scans, to avoid processing data that + // is already known. + private var lastModifiedTime = -1L + + // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted + // into the map in order, so the LinkedHashMap maintains the correct ordering. + @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] + = new mutable.LinkedHashMap() /** * A background thread that periodically checks for event log updates on disk. @@ -91,15 +102,35 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis logCheckingThread.start() } - override def getListing() = appList + override def getListing() = applications.values - override def getAppUI(appId: String): SparkUI = { + override def getAppUI(appId: String): Option[SparkUI] = { try { - val appLogDir = fs.getFileStatus(new Path(resolvedLogDir.toString, appId)) - val (_, ui) = loadAppInfo(appLogDir, renderUI = true) - ui + applications.get(appId).map { info => + val (replayBus, appListener) = createReplayBus(fs.getFileStatus( + new Path(logDir, info.logDir))) + val ui = { + val conf = this.conf.clone() + val appSecManager = new SecurityManager(conf) + new SparkUI(conf, appSecManager, replayBus, appId, + s"${HistoryServer.UI_PATH_PREFIX}/$appId") + // Do not call ui.bind() to avoid creating a new server for each application + } + + replayBus.replay() + + ui.setAppName(s"${appListener.appName.getOrElse(NOT_STARTED)} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(appListener.sparkUser.getOrElse(NOT_STARTED), + appListener.viewAcls.getOrElse("")) + ui + } } catch { - case e: FileNotFoundException => null + case e: FileNotFoundException => None } } @@ -117,84 +148,79 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val logStatus = fs.listStatus(new Path(resolvedLogDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() - val logInfos = logDirs.filter { dir => - fs.isFile(new Path(dir.getPath, EventLoggingListener.APPLICATION_COMPLETE)) - } - val currentApps = Map[String, ApplicationHistoryInfo]( - appList.map(app => app.id -> app):_*) - - // For any application that either (i) is not listed or (ii) has changed since the last time - // the listing was created (defined by the log dir's modification time), load the app's info. - // Otherwise just reuse what's already in memory. - val newApps = new mutable.ArrayBuffer[ApplicationHistoryInfo](logInfos.size) - for (dir <- logInfos) { - val curr = currentApps.getOrElse(dir.getPath().getName(), null) - if (curr == null || curr.lastUpdated < getModificationTime(dir)) { + // Load all new logs from the log directory. Only directories that have a modification time + // later than the last known log directory will be loaded. + var newLastModifiedTime = lastModifiedTime + val logInfos = logDirs + .filter { dir => + if (fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE))) { + val modTime = getModificationTime(dir) + newLastModifiedTime = math.max(newLastModifiedTime, modTime) + modTime > lastModifiedTime + } else { + false + } + } + .flatMap { dir => try { - val (app, _) = loadAppInfo(dir, renderUI = false) - newApps += app + val (replayBus, appListener) = createReplayBus(dir) + replayBus.replay() + Some(new FsApplicationHistoryInfo( + dir.getPath().getName(), + appListener.appId.getOrElse(dir.getPath().getName()), + appListener.appName.getOrElse(NOT_STARTED), + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(dir), + appListener.sparkUser.getOrElse(NOT_STARTED))) } catch { - case e: Exception => logError(s"Failed to load app info from directory $dir.") + case e: Exception => + logInfo(s"Failed to load application log data from $dir.", e) + None + } + } + .sortBy { info => -info.endTime } + + lastModifiedTime = newLastModifiedTime + + // When there are new logs, merge the new list with the existing one, maintaining + // the expected ordering (descending end time). Maintaining the order is important + // to avoid having to sort the list every time there is a request for the log list. + if (!logInfos.isEmpty) { + val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + def addIfAbsent(info: FsApplicationHistoryInfo) = { + if (!newApps.contains(info.id)) { + newApps += (info.id -> info) } - } else { - newApps += curr } - } - appList = newApps.sortBy { info => -info.endTime } + val newIterator = logInfos.iterator.buffered + val oldIterator = applications.values.iterator.buffered + while (newIterator.hasNext && oldIterator.hasNext) { + if (newIterator.head.endTime > oldIterator.head.endTime) { + addIfAbsent(newIterator.next) + } else { + addIfAbsent(oldIterator.next) + } + } + newIterator.foreach(addIfAbsent) + oldIterator.foreach(addIfAbsent) + + applications = newApps + } } catch { case t: Throwable => logError("Exception in checking for event log updates", t) } } - /** - * Parse the application's logs to find out the information we need to build the - * listing page. - * - * When creating the listing of available apps, there is no need to load the whole UI for the - * application. The UI is requested by the HistoryServer (by calling getAppInfo()) when the user - * clicks on a specific application. - * - * @param logDir Directory with application's log files. - * @param renderUI Whether to create the SparkUI for the application. - * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false. - */ - private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = { - val path = logDir.getPath - val appId = path.getName + private def createReplayBus(logDir: FileStatus): (ReplayListenerBus, ApplicationEventListener) = { + val path = logDir.getPath() val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs) val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec) val appListener = new ApplicationEventListener replayBus.addListener(appListener) - - val ui: SparkUI = if (renderUI) { - val conf = this.conf.clone() - val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, - HistoryServer.UI_PATH_PREFIX + s"/$appId") - // Do not call ui.bind() to avoid creating a new server for each application - } else { - null - } - - replayBus.replay() - val appInfo = ApplicationHistoryInfo( - appId, - appListener.appName, - appListener.startTime, - appListener.endTime, - getModificationTime(logDir), - appListener.sparkUser) - - if (ui != null) { - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls) - ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) - } - (appInfo, ui) + (replayBus, appListener) } /** Return when this directory was last modified. */ @@ -217,3 +243,13 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def getMonotonicTimeMs() = System.nanoTime() / (1000 * 1000) } + +private class FsApplicationHistoryInfo( + val logDir: String, + id: String, + name: String, + startTime: Long, + endTime: Long, + lastUpdated: Long, + sparkUser: String) + extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d1a64c1912cb8..ce00c0ffd21e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -52,10 +52,7 @@ class HistoryServer( private val appLoader = new CacheLoader[String, SparkUI] { override def load(key: String): SparkUI = { - val ui = provider.getAppUI(key) - if (ui == null) { - throw new NoSuchElementException() - } + val ui = provider.getAppUI(key).getOrElse(throw new NoSuchElementException()) attachSparkUI(ui) ui } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index d3674427b1271..c3ca43f8d0734 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -96,11 +96,13 @@ private[spark] class ApplicationInfo( def retryCount = _retryCount - def incrementRetryCount = { + def incrementRetryCount() = { _retryCount += 1 _retryCount } + def resetRetryCount() = _retryCount = 0 + def markFinished(endState: ApplicationState.Value) { state = endState endTime = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala index 33377931d6993..80b570a44af18 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -33,4 +33,17 @@ private[spark] class DriverInfo( @transient var exception: Option[Exception] = None /* Most recent worker assigned to this driver */ @transient var worker: Option[WorkerInfo] = None + + init() + + private def readObject(in: java.io.ObjectInputStream): Unit = { + in.defaultReadObject() + init() + } + + private def init(): Unit = { + state = DriverState.SUBMITTED + worker = None + exception = None + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 5017273e87c07..2a3bd6ba0b9dc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -33,7 +33,8 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, + SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState @@ -295,28 +296,34 @@ private[spark] class Master( val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) execOption match { case Some(exec) => { + val appInfo = idToApp(appId) exec.state = state + if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) if (ExecutorState.isFinished(state)) { - val appInfo = idToApp(appId) // Remove this executor from the worker and app - logInfo("Removing executor " + exec.fullId + " because it is " + state) + logInfo(s"Removing executor ${exec.fullId} because it is $state") appInfo.removeExecutor(exec) exec.worker.removeExecutor(exec) - val normalExit = exitStatus.exists(_ == 0) + val normalExit = exitStatus == Some(0) // Only retry certain number of times so we don't go into an infinite loop. - if (!normalExit && appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { - schedule() - } else if (!normalExit) { - logError("Application %s with ID %s failed %d times, removing it".format( - appInfo.desc.name, appInfo.id, appInfo.retryCount)) - removeApplication(appInfo, ApplicationState.FAILED) + if (!normalExit) { + if (appInfo.incrementRetryCount() < ApplicationState.MAX_NUM_RETRY) { + schedule() + } else { + val execs = appInfo.executors.values + if (!execs.exists(_.state == ExecutorState.RUNNING)) { + logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " + + s"${appInfo.retryCount} times; removing it") + removeApplication(appInfo, ApplicationState.FAILED) + } + } } } } case None => - logWarning("Got status update for unknown executor " + appId + "/" + execId) + logWarning(s"Got status update for unknown executor $appId/$execId") } } @@ -480,13 +487,25 @@ private[spark] class Master( if (state != RecoveryState.ALIVE) { return } // First schedule drivers, they take strict precedence over applications - val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers - for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { - for (driver <- List(waitingDrivers: _*)) { // iterate over a copy of waitingDrivers + // Randomization helps balance drivers + val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) + val aliveWorkerNum = shuffledAliveWorkers.size + var curPos = 0 + for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers + // We assign workers to each waiting driver in a round-robin fashion. For each driver, we + // start from the last worker that was assigned a driver, and continue onwards until we have + // explored all alive workers. + curPos = (curPos + 1) % aliveWorkerNum + val startPos = curPos + var launched = false + while (curPos != startPos && !launched) { + val worker = shuffledAliveWorkers(curPos) if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { launchDriver(worker, driver) waitingDrivers -= driver + launched = true } + curPos = (curPos + 1) % aliveWorkerNum } } @@ -673,7 +692,8 @@ private[spark] class Master( app.desc.appUiUrl = notFoundBasePath return false } - val fileSystem = Utils.getHadoopFileSystem(eventLogDir) + val fileSystem = Utils.getHadoopFileSystem(eventLogDir, + SparkHadoopUtil.get.newConfiguration(conf)) val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem) val eventLogPaths = eventLogInfo.logPaths val compressionCodec = eventLogInfo.compressionCodec diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 687e492a0d6fc..12e98fd40d6c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -64,8 +64,6 @@ object CommandUtils extends Logging { Seq() } - val permGenOpt = Seq("-XX:MaxPermSize=128m") - // Figure out our classpath with the external compute-classpath script val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" val classPath = Utils.executeAndGetOutput( @@ -73,6 +71,8 @@ object CommandUtils extends Logging { extraEnvironment = command.environment) val userClassPath = command.classPathEntries ++ Seq(classPath) + val javaVersion = System.getProperty("java.version") + val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 5caaf6bea3575..9f9911762505a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -28,8 +28,8 @@ import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileUtil, Path} -import org.apache.spark.Logging -import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.{Command, DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState @@ -39,6 +39,7 @@ import org.apache.spark.deploy.master.DriverState.DriverState * This is currently only used in standalone cluster deploy mode. */ private[spark] class DriverRunner( + val conf: SparkConf, val driverId: String, val workDir: File, val sparkHome: File, @@ -144,8 +145,8 @@ private[spark] class DriverRunner( val jarPath = new Path(driverDesc.jarUrl) - val emptyConf = new Configuration() - val jarFileSystem = jarPath.getFileSystem(emptyConf) + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + val jarFileSystem = jarPath.getFileSystem(hadoopConf) val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) val jarFileName = jarPath.getName @@ -154,7 +155,7 @@ private[spark] class DriverRunner( if (!localJarFile.exists()) { // May already exist if running multiple workers on one node logInfo(s"Copying user jar $jarPath to $destPath") - FileUtil.copy(jarFileSystem, jarPath, destPath, false, emptyConf) + FileUtil.copy(jarFileSystem, jarPath, destPath, false, hadoopConf) } if (!localJarFile.exists()) { // Verify copy succeeded diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 7be89f9aff0f3..00a43673e5cd3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -159,6 +159,8 @@ private[spark] class ExecutorRunner( Files.write(header, stderr, Charsets.UTF_8) stderrAppender = FileAppender(process.getErrorStream, stderr, conf) + state = ExecutorState.RUNNING + worker ! ExecutorStateChanged(appId, execId, state, None, None) // Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown) // or with nonzero exit code val exitCode = process.waitFor() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 81400af22c0bf..0c454e4138c96 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -234,7 +234,7 @@ private[spark] class Worker( try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.RUNNING) + self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -257,7 +257,7 @@ private[spark] class Worker( val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { - case Some(executor) => + case Some(executor) => logInfo("Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + exitStatus.map(" exitStatus " + _).getOrElse("")) @@ -288,7 +288,7 @@ private[spark] class Worker( case LaunchDriver(driverId, driverDesc) => { logInfo(s"Asked to launch driver $driverId") - val driver = new DriverRunner(driverId, workDir, sparkHome, driverDesc, self, akkaUrl) + val driver = new DriverRunner(conf, driverId, workDir, sparkHome, driverDesc, self, akkaUrl) drivers(driverId) = driver driver.start() diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2f76e532aeb76..acae448a9c66f 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,6 +26,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} @@ -122,6 +123,9 @@ private[spark] class Executor( env.metricsSystem.report() isStopped = true threadPool.shutdown() + if (!isLocal) { + env.stop() + } } class TaskRunner( @@ -294,9 +298,9 @@ private[spark] class Executor( try { val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader], - classOf[Boolean]) - constructor.newInstance(classUri, parent, userClassPathFirst) + val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], + classOf[ClassLoader], classOf[Boolean]) + constructor.newInstance(conf, classUri, parent, userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -313,16 +317,19 @@ private[spark] class Executor( * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) synchronized { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager, + hadoopConf) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager, + hadoopConf) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last @@ -353,7 +360,16 @@ private[spark] class Executor( if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => metrics.updateShuffleReadMetrics - tasksMetrics += ((taskRunner.taskId, metrics)) + if (isLocal) { + // JobProgressListener will hold an reference of it during + // onExecutorMetricsUpdate(), then JobProgressListener can not see + // the changes of metrics any more, so make a deep copy of it + val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) + tasksMetrics += ((taskRunner.taskId, copiedMetrics)) + } else { + // It will be copied by serialization + tasksMetrics += ((taskRunner.taskId, metrics)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ef9c43ecf14f6..1ac7f4e448eb1 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -64,6 +64,7 @@ private[spark] object CompressionCodec { } val DEFAULT_COMPRESSION_CODEC = "snappy" + val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq } diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala similarity index 56% rename from core/src/main/scala/org/apache/spark/network/ReceiverTest.scala rename to core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 53a6038a9b59e..e0e91724271c8 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,21 +17,20 @@ package org.apache.spark.network -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.storage.StorageLevel -private[spark] object ReceiverTest { - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */ - val buffer = ByteBuffer.wrap("response".getBytes("utf-8")) - Some(Message.createBufferMessage(buffer, msg.id)) - }) - Thread.currentThread.join() - } -} +trait BlockDataManager { + + /** + * Interface to get local block data. + * + * @return Some(buffer) if the block exists locally, and None if it doesn't. + */ + def getBlockData(blockId: String): Option[ManagedBuffer] + /** + * Put the block locally, using the given storage level. + */ + def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala new file mode 100644 index 0000000000000..34acaa563ca58 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.util.EventListener + + +/** + * Listener callback interface for [[BlockTransferService.fetchBlocks]]. + */ +trait BlockFetchingListener extends EventListener { + + /** + * Called once per successfully fetched block. + */ + def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit + + /** + * Called upon failures. For each failure, this is called only once (i.e. not once per block). + */ + def onBlockFetchFailure(exception: Throwable): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala new file mode 100644 index 0000000000000..84d991fa6808c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.Duration + +import org.apache.spark.storage.StorageLevel + + +abstract class BlockTransferService { + + /** + * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch + * local blocks or put local blocks. + */ + def init(blockDataManager: BlockDataManager) + + /** + * Tear down the transfer service. + */ + def stop(): Unit + + /** + * Port number the service is listening on, available only after [[init]] is invoked. + */ + def port: Int + + /** + * Host name the service is listening on, available only after [[init]] is invoked. + */ + def hostName: String + + /** + * Fetch a sequence of blocks from a remote node asynchronously, + * available only after [[init]] is invoked. + * + * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, + * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block). + * + * Note that this API takes a sequence so the implementation can batch requests, and does not + * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as + * the data of a block is fetched, rather than waiting for all blocks to be fetched. + */ + def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + */ + def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] + + /** + * A special case of [[fetchBlocks]], as it fetches only one block and is blocking. + * + * It is also only available after [[init]] is invoked. + */ + def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = { + // A monitor for the thread to wait on. + val lock = new Object + @volatile var result: Either[ManagedBuffer, Throwable] = null + fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { + lock.synchronized { + result = Right(exception) + lock.notify() + } + } + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + lock.synchronized { + result = Left(data) + lock.notify() + } + } + }) + + // Sleep until result is no longer null + lock.synchronized { + while (result == null) { + try { + lock.wait() + } catch { + case e: InterruptedException => + } + } + } + + result match { + case Left(data) => data + case Right(e) => throw e + } + } + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + * + * This method is similar to [[uploadBlock]], except this one blocks the thread + * until the upload finishes. + */ + def uploadBlockSync( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel): Unit = { + Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala deleted file mode 100644 index 4894ecd41f6eb..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.nio.ByteBuffer - -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.io.Source - -import org.apache.spark._ - -private[spark] object ConnectionManagerTest extends Logging{ - def main(args: Array[String]) { - // - the master URL - a list slaves to run connectionTest on - // [num of tasks] - the number of parallel tasks to be initiated default is number of slave - // hosts [size of msg in MB (integer)] - the size of messages to be sent in each task, - // default is 10 [count] - how many times to run, default is 3 [await time in seconds] : - // await time (in seconds), default is 600 - if (args.length < 2) { - println("Usage: ConnectionManagerTest [num of tasks] " + - "[size of msg in MB (integer)] [count] [await time in seconds)] ") - System.exit(1) - } - - if (args(0).startsWith("local")) { - println("This runs only on a mesos cluster") - } - - val sc = new SparkContext(args(0), "ConnectionManagerTest") - val slavesFile = Source.fromFile(args(1)) - val slaves = slavesFile.mkString.split("\n") - slavesFile.close() - - /* println("Slaves") */ - /* slaves.foreach(println) */ - val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 - val count = if (args.length > 4) args(4).toInt else 3 - val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second - println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " + - "msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) - val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( - i => SparkEnv.get.connectionManager.id).collect() - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - logInfo("Received [" + msg + "] from [" + id + "]") - None - }) - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId => - { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - } - } - val results = futures.map(f => Await.result(f, awaitTime)) - val finishTime = System.currentTimeMillis - Thread.sleep(5000) - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * - 1000.0) + " MB/s" - logInfo(resultStr) - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } -} - diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala new file mode 100644 index 0000000000000..dcecb6beeea9b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.io.{FileInputStream, RandomAccessFile, File, InputStream} +import java.nio.ByteBuffer +import java.nio.channels.FileChannel.MapMode + +import com.google.common.io.ByteStreams +import io.netty.buffer.{ByteBufInputStream, ByteBuf} + +import org.apache.spark.util.ByteBufferInputStream + + +/** + * This interface provides an immutable view for data in the form of bytes. The implementation + * should specify how the data is provided: + * + * - FileSegmentManagedBuffer: data backed by part of a file + * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer + * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf + */ +sealed abstract class ManagedBuffer { + // Note that all the methods are defined with parenthesis because their implementations can + // have side effects (io operations). + + /** Number of bytes of the data. */ + def size: Long + + /** + * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the + * returned ByteBuffer should not affect the content of this buffer. + */ + def nioByteBuffer(): ByteBuffer + + /** + * Exposes this buffer's data as an InputStream. The underlying implementation does not + * necessarily check for the length of bytes read, so the caller is responsible for making sure + * it does not go over the limit. + */ + def inputStream(): InputStream +} + + +/** + * A [[ManagedBuffer]] backed by a segment in a file + */ +final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) + extends ManagedBuffer { + + override def size: Long = length + + override def nioByteBuffer(): ByteBuffer = { + val channel = new RandomAccessFile(file, "r").getChannel + channel.map(MapMode.READ_ONLY, offset, length) + } + + override def inputStream(): InputStream = { + val is = new FileInputStream(file) + is.skip(offset) + ByteStreams.limit(is, length) + } +} + + +/** + * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. + */ +final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { + + override def size: Long = buf.remaining() + + override def nioByteBuffer() = buf.duplicate() + + override def inputStream() = new ByteBufferInputStream(buf) +} + + +/** + * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. + */ +final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { + + override def size: Long = buf.readableBytes() + + override def nioByteBuffer() = buf.nioBuffer() + + override def inputStream() = new ByteBufInputStream(buf) + + // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. + def release(): Unit = buf.release() +} diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala deleted file mode 100644 index ea2ad104ecae1..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.Try - -private[spark] object SenderTest { - def main(args: Array[String]) { - - if (args.length < 2) { - println("Usage: SenderTest ") - System.exit(1) - } - - val targetHost = args(0) - val targetPort = args(1).toInt - val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - val conf = new SparkConf - val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val targetServer = args(0) - - val count = 100 - (0 until count).foreach(i => { - val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis - /* println("Started timer at " + startTime) */ - val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) - val responseStr: String = Try(Await.result(promise, Duration.Inf)) - .map { response => - val buffer = response.asInstanceOf[BufferMessage].buffers(0) - new String(buffer.array, "utf-8") - }.getOrElse("none") - - val finishTime = System.currentTimeMillis - val mb = size / 1024.0 / 1024.0 - val ms = finishTime - startTime - // val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms - // * 1000.0) + " MB/s" - val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + - (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr - println(resultStr) - }) - } -} - diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala similarity index 89% rename from core/src/main/scala/org/apache/spark/storage/BlockMessage.scala rename to core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index a2bfce7b4a0fa..b573f1a8a5fcb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -15,20 +15,20 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.StringBuilder +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} -import org.apache.spark.network._ +import scala.collection.mutable.{ArrayBuffer, StringBuilder} +// private[spark] because we need to register them in Kryo private[spark] case class GetBlock(id: BlockId) private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) -private[spark] class BlockMessage() { +private[nio] class BlockMessage() { // Un-initialized: typ = 0 // GetBlock: typ = 1 // GotBlock: typ = 2 @@ -159,7 +159,7 @@ private[spark] class BlockMessage() { } } -private[spark] object BlockMessage { +private[nio] object BlockMessage { val TYPE_NON_INITIALIZED: Int = 0 val TYPE_GET_BLOCK: Int = 1 val TYPE_GOT_BLOCK: Int = 2 @@ -194,16 +194,4 @@ private[spark] object BlockMessage { newBlockMessage.set(putBlock) newBlockMessage } - - def main(args: Array[String]) { - val B = new BlockMessage() - val blockId = TestBlockId("ABC") - B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) - val bMsg = B.toBufferMessage - val C = new BlockMessage() - C.set(bMsg) - - println(B.getId + " " + B.getLevel) - println(C.getId + " " + C.getLevel) - } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala rename to core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 973d85c0a9b3a..a1a2c00ed1542 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - import org.apache.spark._ -import org.apache.spark.network._ +import org.apache.spark.storage.{StorageLevel, TestBlockId} + +import scala.collection.mutable.ArrayBuffer -private[spark] +private[nio] class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { @@ -102,7 +102,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) } } -private[spark] object BlockMessageArray { +private[nio] object BlockMessageArray { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/network/BufferMessage.scala rename to core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala index af35f1fc3e459..3b245c5c7a4f3 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.nio.ByteBuffer @@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.storage.BlockManager -private[spark] + +private[nio] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) extends Message(Message.BUFFER_MESSAGE, id_) { diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/Connection.scala rename to core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 5285ec82c1b64..74074a8dcbfff 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -15,17 +15,17 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ -import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} - import org.apache.spark._ -private[spark] +import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} + +private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) extends Logging { @@ -190,7 +190,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, } -private[spark] +private[nio] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, remoteId_ : ConnectionManagerId, id_ : ConnectionId) extends Connection(SocketChannel.open, selector_, remoteId_, id_) { diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/network/ConnectionId.scala rename to core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala index d579c165a1917..764dc5e5503ed 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio -private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { +private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId } -private[spark] object ConnectionId { +private[nio] object ConnectionId { def createConnectionIdFromString(connectionIdString: String): ConnectionId = { val res = connectionIdString.split("_").map(_.trim()) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/network/ConnectionManager.scala rename to core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index b3e951ded6e77..09d3ea306515b 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -15,32 +15,27 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.io.IOException +import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ -import java.net._ -import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} +import java.util.{Timer, TimerTask} -import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.SynchronizedMap -import scala.collection.mutable.SynchronizedQueue - -import scala.concurrent.{Await, ExecutionContext, Future, Promise} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.{SystemClock, Utils} -private[spark] class ConnectionManager( + +private[nio] class ConnectionManager( port: Int, conf: SparkConf, securityManager: SecurityManager, @@ -418,7 +413,7 @@ private[spark] class ConnectionManager( newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + logInfo("Accepted connection from [" + newConnection.remoteAddress + "]") } catch { // might happen in case of issues with registering with selector case e: Exception => logError("Error in accept loop", e) @@ -851,8 +846,8 @@ private[spark] class ConnectionManager( messageStatuses.synchronized { messageStatuses.remove(message.id).foreach ( s => { promise.failure( - new IOException(s"sendMessageReliably failed because ack " + - "was not received within ${ackTimeout} sec")) + new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec")) }) } } @@ -904,7 +899,7 @@ private[spark] class ConnectionManager( private[spark] object ConnectionManager { - import ExecutionContext.Implicits.global + import scala.concurrent.ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala rename to core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala index 57f7586883af1..cbb37ec5ced1f 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.net.InetSocketAddress import org.apache.spark.util.Utils -private[spark] case class ConnectionManagerId(host: String, port: Int) { +private[nio] case class ConnectionManagerId(host: String, port: Int) { // DEBUG code Utils.checkHost(host) assert (port > 0) @@ -30,7 +30,7 @@ private[spark] case class ConnectionManagerId(host: String, port: Int) { } -private[spark] object ConnectionManagerId { +private[nio] object ConnectionManagerId { def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/network/Message.scala rename to core/src/main/scala/org/apache/spark/network/nio/Message.scala index 04ea50f62918c..0b874c2891255 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -15,14 +15,15 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.net.InetSocketAddress import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -private[spark] abstract class Message(val typ: Long, val id: Int) { + +private[nio] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null var started = false var startTime = -1L @@ -42,7 +43,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { } -private[spark] object Message { +private[nio] object Message { val BUFFER_MESSAGE = 1111111111L var lastId = 1 diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/network/MessageChunk.scala rename to core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala index d0f986a12bfe0..278c5ac356ef2 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -private[network] +private[nio] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { val size = if (buffer == null) 0 else buffer.remaining diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala similarity index 93% rename from core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala rename to core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala index f3ecca5f992e0..6e20f291c5cec 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -15,13 +15,12 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio -import java.net.InetAddress -import java.net.InetSocketAddress +import java.net.{InetAddress, InetSocketAddress} import java.nio.ByteBuffer -private[spark] class MessageChunkHeader( +private[nio] class MessageChunkHeader( val typ: Long, val id: Int, val totalSize: Int, @@ -57,7 +56,7 @@ private[spark] class MessageChunkHeader( } -private[spark] object MessageChunkHeader { +private[nio] object MessageChunkHeader { val HEADER_SIZE = 45 def create(buffer: ByteBuffer): MessageChunkHeader = { diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala new file mode 100644 index 0000000000000..59958ee894230 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.nio + +import java.nio.ByteBuffer + +import scala.concurrent.Future + +import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} +import org.apache.spark.network._ +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.Utils + + +/** + * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom + * implementation using Java NIO. + */ +final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService with Logging { + + private var cm: ConnectionManager = _ + + private var blockDataManager: BlockDataManager = _ + + /** + * Port number the service is listening on, available only after [[init]] is invoked. + */ + override def port: Int = { + checkInit() + cm.id.port + } + + /** + * Host name the service is listening on, available only after [[init]] is invoked. + */ + override def hostName: String = { + checkInit() + cm.id.host + } + + /** + * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch + * local blocks or put local blocks. + */ + override def init(blockDataManager: BlockDataManager): Unit = { + this.blockDataManager = blockDataManager + cm = new ConnectionManager( + conf.getInt("spark.blockManager.port", 0), + conf, + securityManager, + "Connection manager for block manager") + cm.onReceiveMessage(onBlockMessageReceive) + } + + /** + * Tear down the transfer service. + */ + override def stop(): Unit = { + if (cm != null) { + cm.stop() + } + } + + override def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit = { + checkInit() + + val cmId = new ConnectionManagerId(hostName, port) + val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => + BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) + }) + + val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + + // Register the listener on success/failure future callback. + future.onSuccess { case message => + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + listener.onBlockFetchFailure( + new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + } else { + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + listener.onBlockFetchSuccess( + blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + } + } + }(cm.futureExecContext) + + future.onFailure { case exception => + listener.onBlockFetchFailure(exception) + }(cm.futureExecContext) + } + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + * + * This call blocks until the upload completes, or throws an exception upon failures. + */ + override def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel) + : Future[Unit] = { + checkInit() + val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level) + val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) + val remoteCmId = new ConnectionManagerId(hostName, port) + val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) + reply.map(x => ())(cm.futureExecContext) + } + + private def checkInit(): Unit = if (cm == null) { + throw new IllegalStateException(getClass.getName + " has not been initialized") + } + + private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { + logDebug("Handling message " + msg) + msg match { + case bufferMessage: BufferMessage => + try { + logDebug("Handling as a buffer message " + bufferMessage) + val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) + logDebug("Parsed as a block message array") + val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) + Some(new BlockMessageArray(responseMessages).toBufferMessage) + } catch { + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + case otherMessage: Any => + logError("Unknown type message received: " + otherMessage) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { + blockMessage.getType match { + case BlockMessage.TYPE_PUT_BLOCK => + val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) + logDebug("Received [" + msg + "]") + putBlock(msg.id.toString, msg.data, msg.level) + None + + case BlockMessage.TYPE_GET_BLOCK => + val msg = new GetBlock(blockMessage.getId) + logDebug("Received [" + msg + "]") + val buffer = getBlock(msg.id.toString) + if (buffer == null) { + return None + } + Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) + + case _ => None + } + } + + private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + val startTimeMs = System.currentTimeMillis() + logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) + blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " with data size: " + bytes.limit) + } + + private def getBlock(blockId: String): ByteBuffer = { + val startTimeMs = System.currentTimeMillis() + logDebug("GetBlock " + blockId + " started from " + startTimeMs) + val buffer = blockDataManager.getBlockData(blockId).orNull + logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " and got buffer " + buffer) + buffer.nioByteBuffer() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/network/SecurityMessage.scala rename to core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index 9af9e2e8e9e59..747a2088a7258 100644 --- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -15,15 +15,13 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.StringBuilder +import scala.collection.mutable.{ArrayBuffer, StringBuilder} import org.apache.spark._ -import org.apache.spark.network._ /** * SecurityMessage is class that contains the connectionId and sasl token @@ -54,7 +52,7 @@ import org.apache.spark.network._ * - Length of the token * - Token */ -private[spark] class SecurityMessage() extends Logging { +private[nio] class SecurityMessage extends Logging { private var connectionId: String = null private var token: Array[Byte] = null @@ -134,7 +132,7 @@ private[spark] class SecurityMessage() extends Logging { } } -private[spark] object SecurityMessage { +private[nio] object SecurityMessage { /** * Convert the given BufferMessage to a SecurityMessage by parsing the contents diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 5cdbc306e56a0..e2fc9c649925e 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -44,4 +44,5 @@ package org.apache package object spark { // For package docs only + val SPARK_VERSION = "1.2.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index a74f80094434c..d5336284571d2 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -22,7 +22,7 @@ import cern.jet.stat.Probability import org.apache.spark.util.StatCounter /** - * An ApproximateEvaluator for sums. It estimates the mean and the cont and multiplies them + * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them * together, then uses the formula for the variance of two independent random variables to get * a variance for the result and compute a confidence interval. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index aed951a40b40c..b62f3fbdc4a15 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -112,7 +112,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Applies a function f to all elements of this RDD. */ def foreachAsync(f: T => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size), + val cleanF = self.context.clean(f) + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), (index, data) => Unit, Unit) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 20938781ac694..7ba1182f0ed27 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -157,7 +157,7 @@ private[spark] object CheckpointRDD extends Logging { val sc = new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val conf = SparkHadoopUtil.get.newConfiguration() + val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) val fs = path.getFileSystem(conf) val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index c8623314c98eb..036dcc49664ef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -42,7 +42,8 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataReadMethod, InputMetrics} import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.NextIterator +import org.apache.spark.util.{NextIterator, Utils} + /** * A Spark split class that wraps around a Hadoop InputSplit. @@ -228,7 +229,11 @@ class HadoopRDD[K, V]( try { reader.close() } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + case e: Exception => { + if (!Utils.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 58f707b9b4634..4c84b3f62354d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.util.Utils private[spark] class NewHadoopPartition( rddId: Int, @@ -153,7 +154,11 @@ class NewHadoopRDD[K, V]( try { reader.close() } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + case e: Exception => { + if (!Utils.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index e98bad2026e32..d0dbfef35d03c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Logging, RangePartitioner} +import org.apache.spark.{Logging, Partitioner, RangePartitioner} import org.apache.spark.annotation.DeveloperApi /** @@ -64,4 +64,16 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, new ShuffledRDD[K, V, V](self, part) .setKeyOrdering(if (ascending) ordering else ordering.reverse) } + + /** + * Repartition the RDD according to the given partitioner and, within each resulting partition, + * sort records by their keys. + * + * This is more efficient than calling `repartition` and then sorting within each partition + * because it can push the sorting down into the shuffle machinery. + */ + def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = { + new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering) + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index daea2617e62ea..a9b905b0d1a63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -993,7 +993,7 @@ abstract class RDD[T: ClassTag]( */ @Experimental def countApproxDistinct(p: Int, sp: Int): Long = { - require(p >= 4, s"p ($p) must be greater than 0") + require(p >= 4, s"p ($p) must be at least 4") require(sp <= 32, s"sp ($sp) cannot be greater than 32") require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)") val zeroCounter = new HyperLogLogPlus(p, sp) @@ -1064,11 +1064,10 @@ abstract class RDD[T: ClassTag]( // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. - // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. + // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, + // interpolate the number of partitions we need to try, but overestimate it by 50%. if (buf.size == 0) { - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 } else { numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt } @@ -1128,15 +1127,19 @@ abstract class RDD[T: ClassTag]( * @return an array of top elements */ def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = { - mapPartitions { items => - // Priority keeps the largest elements, so let's reverse the ordering. - val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) - Iterator.single(queue) - }.reduce { (queue1, queue2) => - queue1 ++= queue2 - queue1 - }.toArray.sorted(ord) + if (num == 0) { + Array.empty + } else { + mapPartitions { items => + // Priority keeps the largest elements, so let's reverse the ordering. + val queue = new BoundedPriorityQueue[T](num)(ord.reverse) + queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + Iterator.single(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray.sorted(ord) + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 162158babc35b..6d39a5e3fa64c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -24,38 +24,31 @@ package org.apache.spark.scheduler * from multiple applications are seen, the behavior is unspecified. */ private[spark] class ApplicationEventListener extends SparkListener { - var appName = "" - var sparkUser = "" - var startTime = -1L - var endTime = -1L - var viewAcls = "" - var adminAcls = "" - - def applicationStarted = startTime != -1 - - def applicationCompleted = endTime != -1 - - def applicationDuration: Long = { - val difference = endTime - startTime - if (applicationStarted && applicationCompleted && difference > 0) difference else -1L - } + var appName: Option[String] = None + var appId: Option[String] = None + var sparkUser: Option[String] = None + var startTime: Option[Long] = None + var endTime: Option[Long] = None + var viewAcls: Option[String] = None + var adminAcls: Option[String] = None override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { - appName = applicationStart.appName - startTime = applicationStart.time - sparkUser = applicationStart.sparkUser + appName = Some(applicationStart.appName) + appId = applicationStart.appId + startTime = Some(applicationStart.time) + sparkUser = Some(applicationStart.sparkUser) } override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { - endTime = applicationEnd.time + endTime = Some(applicationEnd.time) } override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { val environmentDetails = environmentUpdate.environmentDetails val allProperties = environmentDetails("Spark Properties").toMap - viewAcls = allProperties.getOrElse("spark.ui.view.acls", "") - adminAcls = allProperties.getOrElse("spark.admin.acls", "") + viewAcls = allProperties.get("spark.ui.view.acls") + adminAcls = allProperties.get("spark.admin.acls") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b86cfbfa48fbe..6fcf9e31543ed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -164,7 +164,7 @@ class DAGScheduler( */ def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics) + taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) implicit val timeout = Timeout(600 seconds) @@ -241,9 +241,9 @@ class DAGScheduler( callSite: CallSite) : Stage = { + val parentStages = getParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = - new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) + val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -677,7 +677,10 @@ class DAGScheduler( } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { - listenerBus.post(SparkListenerTaskStart(task.stageId, taskInfo)) + // Note that there is a chance that this task is launched after the stage is cancelled. + // In that case, we wouldn't have the stage anymore in stageIdToStage. + val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) submitWaitingStages() } @@ -695,8 +698,8 @@ class DAGScheduler( // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" runningStages.foreach { stage => - stage.info.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.info)) + stage.latestInfo.stageFailed(stageFailedMessage) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) } listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } @@ -781,7 +784,16 @@ class DAGScheduler( logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() - var tasks = ArrayBuffer[Task[_]]() + + // First figure out the indexes of partition ids to compute. + val partitionsToCompute: Seq[Int] = { + if (stage.isShuffleMap) { + (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil) + } else { + val job = stage.resultOfJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) + } + } val properties = if (jobIdToActiveJob.contains(jobId)) { jobIdToActiveJob(stage.jobId).properties @@ -795,7 +807,8 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast @@ -826,20 +839,19 @@ class DAGScheduler( return } - if (stage.isShuffleMap) { - for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { - val locs = getPreferredLocs(stage.rdd, p) - val part = stage.rdd.partitions(p) - tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs) + val tasks: Seq[Task[_]] = if (stage.isShuffleMap) { + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) } } else { - // This is a final stage; figure out its job's missing partitions val job = stage.resultOfJob.get - for (id <- 0 until job.numPartitions if !job.finished(id)) { + partitionsToCompute.map { id => val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = getPreferredLocs(stage.rdd, p) - tasks += new ResultTask(stage.id, taskBinary, part, locs, id) + new ResultTask(stage.id, taskBinary, part, locs, id) } } @@ -869,11 +881,11 @@ class DAGScheduler( logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stage.info.submissionTime = Some(clock.getTime()) + stage.latestInfo.submissionTime = Some(clock.getTime()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should post // SparkListenerStageCompleted here in case there are no tasks to run. - listenerBus.post(SparkListenerStageCompleted(stage.info)) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) runningStages -= stage @@ -892,8 +904,9 @@ class DAGScheduler( // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, + event.taskInfo, event.taskMetrics)) } if (!stageIdToStage.contains(task.stageId)) { @@ -902,14 +915,19 @@ class DAGScheduler( } val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage) = { - val serviceTime = stage.info.submissionTime match { + def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { + val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) case _ => "Unknown" } - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.info.completionTime = Some(clock.getTime()) - listenerBus.post(SparkListenerStageCompleted(stage.info)) + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTime()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } event.reason match { @@ -924,7 +942,7 @@ class DAGScheduler( val name = acc.name.get val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) val stringValue = Accumulators.stringifyValue(acc.value) - stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) event.taskInfo.accumulables += AccumulableInfo(id, name, Some(stringPartialValue), stringValue) } @@ -935,8 +953,8 @@ class DAGScheduler( logError(s"Failed to update accumulators for $task", e) } } - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, + event.reason, event.taskInfo, event.taskMetrics)) stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => @@ -1027,30 +1045,39 @@ class DAGScheduler( stage.pendingTasks += task case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - // Mark the stage that the reducer was in as unrunnable val failedStage = stageIdToStage(task.stageId) - runningStages -= failedStage - // TODO: Cancel running tasks in the stage - logInfo("Marking " + failedStage + " (" + failedStage.name + - ") for resubmision due to a fetch failure") - // Mark the map whose fetch failed as broken in the map stage val mapStage = shuffleToMapStage(shuffleId) - if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is possible + // the fetch failure has already been handled by the scheduler. + if (runningStages.contains(failedStage)) { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + s"due to a fetch failure from $mapStage (${mapStage.name})") + markStageAsFinished(failedStage, Some("Fetch failure")) + runningStages -= failedStage } - logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name + - "); marking it for resubmission") + if (failedStages.isEmpty && eventProcessActor != null) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. eventProcessActor may be // null during unit tests. + // TODO: Cancel running tasks in the stage import env.actorSystem.dispatcher + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") env.actorSystem.scheduler.scheduleOnce( RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages) } failedStages += failedStage failedStages += mapStage + + // Mark the map whose fetch failed as broken in the map stage + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { handleExecutorLost(bmAddress.executorId, Some(task.epoch)) @@ -1142,7 +1169,7 @@ class DAGScheduler( } val dependentJobs: Seq[ActiveJob] = activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq - failedStage.info.completionTime = Some(clock.getTime()) + failedStage.latestInfo.completionTime = Some(clock.getTime()) for (job <- dependentJobs) { failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") } @@ -1182,8 +1209,8 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.info.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.info)) + stage.latestInfo.stageFailed(failureReason) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 370fcd85aa680..64b32ae0edaac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec +import org.apache.spark.SPARK_VERSION import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} /** @@ -44,11 +45,14 @@ import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} private[spark] class EventLoggingListener( appName: String, sparkConf: SparkConf, - hadoopConf: Configuration = SparkHadoopUtil.get.newConfiguration()) + hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ + def this(appName: String, sparkConf: SparkConf) = + this(appName, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) + private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) @@ -83,7 +87,7 @@ private[spark] class EventLoggingListener( sparkConf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC) logger.newFile(COMPRESSION_CODEC_PREFIX + codec) } - logger.newFile(SPARK_VERSION_PREFIX + SparkContext.SPARK_VERSION) + logger.newFile(SPARK_VERSION_PREFIX + SPARK_VERSION) logger.newFile(LOG_PREFIX + logger.fileIndex) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index e9bfee2248e5b..29879b374b801 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -23,7 +23,7 @@ package org.apache.spark.scheduler */ private[spark] class JobWaiter[T]( dagScheduler: DAGScheduler, - jobId: Int, + val jobId: Int, totalTasks: Int, resultHandler: (Int, T) => Unit) extends JobListener { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index e41e0a9841691..a0be8307eff27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -31,4 +31,12 @@ private[spark] trait SchedulerBackend { def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = throw new UnsupportedOperationException def isReady(): Boolean = true + + /** + * The application ID associated with the job, if any. + * + * @return The application ID, or None if the backend does not provide an ID. + */ + def applicationId(): Option[String] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index d01d318633877..86afe3bd5265f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -39,7 +39,8 @@ case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Propert case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent @DeveloperApi -case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent +case class SparkListenerTaskStart(stageId: Int, stageAttemptId: Int, taskInfo: TaskInfo) + extends SparkListenerEvent @DeveloperApi case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent @@ -47,6 +48,7 @@ case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListe @DeveloperApi case class SparkListenerTaskEnd( stageId: Int, + stageAttemptId: Int, taskType: String, reason: TaskEndReason, taskInfo: TaskInfo, @@ -65,25 +67,30 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerAdded(blockManagerId: BlockManagerId, maxMem: Long) +case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long) extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) +case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId) extends SparkListenerEvent @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +/** + * Periodic updates from executors. + * @param execId executor id + * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics) + */ @DeveloperApi case class SparkListenerExecutorMetricsUpdate( execId: String, - taskMetrics: Seq[(Long, Int, TaskMetrics)]) + taskMetrics: Seq[(Long, Int, Int, TaskMetrics)]) extends SparkListenerEvent @DeveloperApi -case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) - extends SparkListenerEvent +case class SparkListenerApplicationStart(appName: String, appId: Option[String], time: Long, + sparkUser: String) extends SparkListenerEvent @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 800905413d145..071568cdfb429 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -43,6 +43,9 @@ import org.apache.spark.util.CallSite * stage, the callSite gives the user code that created the RDD being shuffled. For a result * stage, the callSite gives the user code that executes the associated action (e.g. count()). * + * A single stage can consist of multiple attempts. In that case, the latestInfo field will + * be updated for each attempt. + * */ private[spark] class Stage( val id: Int, @@ -71,8 +74,8 @@ private[spark] class Stage( val name = callSite.shortForm val details = callSite.longForm - /** Pointer to the [StageInfo] object, set by DAGScheduler. */ - var info: StageInfo = StageInfo.fromStage(this) + /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ + var latestInfo: StageInfo = StageInfo.fromStage(this) def isAvailable: Boolean = { if (!isShuffleMap) { @@ -116,6 +119,7 @@ private[spark] class Stage( } } + /** Return a new attempt id, starting with 0. */ def newAttemptId(): Int = { val id = nextAttemptId nextAttemptId += 1 diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 2a407e47a05bd..c6dc3369ba5cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -29,6 +29,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, + val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,9 +57,15 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. */ - def fromStage(stage: Stage): StageInfo = { + def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos - new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos, stage.details) + new StageInfo( + stage.id, + stage.attemptId, + stage.name, + numTasks.getOrElse(stage.numTasks), + rddInfos, + stage.details) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 1a0b877c8a5e1..1c1ce666eab0f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -64,4 +64,12 @@ private[spark] trait TaskScheduler { */ def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], blockManagerId: BlockManagerId): Boolean + + /** + * The application ID associated with the job, if any. + * + * @return The application ID, or None if the backend does not provide an ID. + */ + def applicationId(): Option[String] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 6c0d1b2752a81..633e892554c50 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -333,12 +333,12 @@ private[spark] class TaskSchedulerImpl( execId: String, taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId): Boolean = { - val metricsWithStageIds = taskMetrics.flatMap { - case (id, metrics) => { + + val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { + taskMetrics.flatMap { case (id, metrics) => taskIdToTaskSetId.get(id) .flatMap(activeTaskSets.get) - .map(_.stageId) - .map(x => (id, x, metrics)) + .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) } } dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) @@ -491,6 +491,9 @@ private[spark] class TaskSchedulerImpl( } } } + + override def applicationId(): Option[String] = backend.applicationId() + } @@ -535,4 +538,5 @@ private[spark] object TaskSchedulerImpl { retval.toList } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 613fa7850bb25..c3ad325156f53 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,9 +31,5 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." + attempt - def kill(interruptThread: Boolean) { - tasks.foreach(_.kill(interruptThread)) - } - override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2a3711ae2a78c..9a0cb1c6c6ccd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -51,12 +51,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - // Submit tasks only after (registered resources / total expected resources) + // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. var minRegisteredRatio = math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds - // if minRegisteredRatio has not yet been reached + // if minRegisteredRatio has not yet been reached val maxRegisteredWaitingTime = conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() @@ -292,7 +292,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") conf.set("spark.ui.filters", filterName) conf.set(s"spark.$filterName.params", filterParams) - JettyUtils.addFilters(scheduler.sc.ui.getHandlers, conf) + scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index d99c76117c168..ee10aa061f4e9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -17,10 +17,10 @@ package org.apache.spark.scheduler.cluster -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.{Logging, SparkContext, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class SimrSchedulerBackend( @@ -38,22 +38,25 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port"), + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val conf = new Configuration() + val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) + val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") logInfo("Writing to HDFS file: " + driverFilePath) logInfo("Writing Akka address: " + driverUrl) - logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress) + logInfo("Writing Spark UI Address: " + appUIAddress) // Create temporary file to prevent race condition where executors get empty driverUrl file val temp = fs.create(tmpPath, true) temp.writeUTF(driverUrl) temp.writeInt(maxCores) - temp.writeUTF(sc.ui.appUIAddress) + temp.writeUTF(appUIAddress) temp.close() // "Atomic" rename @@ -61,9 +64,10 @@ private[spark] class SimrSchedulerBackend( } override def stop() { - val conf = new Configuration() + val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) fs.delete(new Path(driverFilePath), false) super.stop() } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 589dba2e40d20..2f45d192e1d4d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} @@ -34,6 +34,10 @@ private[spark] class SparkDeploySchedulerBackend( var client: AppClient = null var stopping = false var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ + var appId: String = _ + + val registrationLock = new Object() + var registrationDone = false val maxCores = conf.getOption("spark.cores.max").map(_.toInt) val totalExpectedCores = maxCores.getOrElse(0) @@ -42,8 +46,10 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - conf.get("spark.driver.host"), conf.get("spark.driver.port"), + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, + conf.get("spark.driver.host"), + conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") @@ -61,11 +67,15 @@ private[spark] class SparkDeploySchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) + val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") + val eventLogDir = sc.eventLogger.map(_.logDir) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) + appUIAddress, eventLogDir) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() + + waitForRegistration() } override def stop() { @@ -79,15 +89,19 @@ private[spark] class SparkDeploySchedulerBackend( override def connected(appId: String) { logInfo("Connected to Spark cluster with app ID " + appId) + this.appId = appId + notifyContext() } override def disconnected() { + notifyContext() if (!stopping) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { + notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) scheduler.error(reason) @@ -114,4 +128,22 @@ private[spark] class SparkDeploySchedulerBackend( override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } + + override def applicationId(): Option[String] = Option(appId) + + private def waitForRegistration() = { + registrationLock.synchronized { + while (!registrationDone) { + registrationLock.wait() + } + } + } + + private def notifyContext() = { + registrationLock.synchronized { + registrationDone = true + registrationLock.notifyAll() + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 9f45400bcf852..64568409dbafd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -28,7 +28,7 @@ import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} -import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -71,9 +71,6 @@ private[spark] class CoarseMesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Int, String] val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( - "Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor")) val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) @@ -110,6 +107,11 @@ private[spark] class CoarseMesosSchedulerBackend( } def createCommand(offer: Offer, numCores: Int): CommandInfo = { + val executorSparkHome = conf.getOption("spark.mesos.executor.home") + .orElse(sc.getSparkHome()) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() val extraClassPath = conf.getOption("spark.executor.extraClassPath") extraClassPath.foreach { cp => @@ -122,6 +124,12 @@ private[spark] class CoarseMesosSchedulerBackend( val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p") val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ") + environment.addVariables( + Environment.Variable.newBuilder() + .setName("SPARK_EXECUTOR_OPTS") + .setValue(extraOpts) + .build()) + sc.executorEnvs.foreach { case (key, value) => environment.addVariables(Environment.Variable.newBuilder() .setName(key) @@ -130,25 +138,26 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) val uri = conf.get("spark.executor.uri", null) if (uri == null) { - val runScript = new File(sparkHome, "./bin/spark-class").getCanonicalPath + val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %s %d".format( - runScript, extraOpts, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( ("cd %s*; " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %s %d") - .format(basename, extraOpts, driverUrl, offer.getSlaveId.getValue, + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d") + .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } @@ -300,4 +309,5 @@ private[spark] class CoarseMesosSchedulerBackend( logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c717e7c621a8f..a9ef126f5de0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -86,10 +86,26 @@ private[spark] class MesosSchedulerBackend( } def createExecutorInfo(execId: String): ExecutorInfo = { - val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( - "Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor")) + val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") + .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() + sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => + environment.addVariables( + Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) + } + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") + val extraLibraryPath = sc.conf.getOption("spark.executor.extraLibraryPath").map { lp => + s"-Djava.library.path=$lp" + } + val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ") + environment.addVariables( + Environment.Variable.newBuilder() + .setName("SPARK_EXECUTOR_OPTS") + .setValue(extraOpts) + .build()) sc.executorEnvs.foreach { case (key, value) => environment.addVariables(Environment.Variable.newBuilder() .setName(key) @@ -100,7 +116,7 @@ private[spark] class MesosSchedulerBackend( .setEnvironment(environment) val uri = sc.conf.get("spark.executor.uri", null) if (uri == null) { - command.setValue(new File(sparkHome, "/sbin/spark-executor").getCanonicalPath) + command.setValue(new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". @@ -333,4 +349,5 @@ private[spark] class MesosSchedulerBackend( // TODO: query Mesos for number of cores override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index bec9502f20466..9ea25c2bc7090 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -114,4 +114,5 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { localActor ! StatusUpdate(taskId, state, serializedData) } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 87ef9bb0b43c6..d6386f8c06fff 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,9 +27,9 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ -import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala similarity index 82% rename from core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala rename to core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index b8f5d3a5b02aa..439981d232349 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -15,22 +15,23 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.shuffle import java.io.File +import java.nio.ByteBuffer import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ -import org.apache.spark.Logging +import org.apache.spark.{SparkEnv, SparkConf, Logging} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleManager -import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup +import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup +import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} -import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.executor.ShuffleWriteMetrics /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -61,20 +62,18 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroups associated with the block's reducer. */ -// TODO: Factor this into a separate class for each ShuffleManager implementation + private[spark] -class ShuffleBlockManager(blockManager: BlockManager, - shuffleManager: ShuffleManager) extends Logging { - def conf = blockManager.conf +class FileShuffleBlockManager(conf: SparkConf) + extends ShuffleBlockManager with Logging { + + private lazy val blockManager = SparkEnv.get.blockManager // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. - val consolidateShuffleFiles = + private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Are we using sort-based shuffle? - val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager] - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 /** @@ -93,22 +92,11 @@ class ShuffleBlockManager(blockManager: BlockManager, val completedMapTasks = new ConcurrentLinkedQueue[Int]() } - type ShuffleId = Int private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) - /** - * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle - * because it just writes a single file by itself. - */ - def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) - val shuffleState = shuffleStates(shuffleId) - shuffleState.completedMapTasks.add(mapId) - } - /** * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully @@ -168,7 +156,7 @@ class ShuffleBlockManager(blockManager: BlockManager, val filename = physicalFileName(shuffleId, bucketId, fileId) blockManager.diskBlockManager.getFile(filename) } - val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files) + val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) shuffleState.allFileGroups.add(fileGroup) fileGroup } @@ -179,19 +167,28 @@ class ShuffleBlockManager(blockManager: BlockManager, } } - /** - * Returns the physical file segment in which the given BlockId is located. - * This function should only be called if shuffle file consolidation is enabled, as it is - * an error condition if we don't find the expected block. - */ - def getBlockLocation(id: ShuffleBlockId): FileSegment = { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(id.shuffleId) - for (fileGroup <- shuffleState.allFileGroups) { - val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId) - if (segment.isDefined) { return segment.get } + override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { + val segment = getBlockData(blockId) + Some(segment.nioByteBuffer()) + } + + override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + if (consolidateShuffleFiles) { + // Search all file groups associated with this shuffle. + val shuffleState = shuffleStates(blockId.shuffleId) + val iter = shuffleState.allFileGroups.iterator + while (iter.hasNext) { + val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) + if (segmentOpt.isDefined) { + val segment = segmentOpt.get + return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length) + } + } + throw new IllegalStateException("Failed to find shuffle block: " + blockId) + } else { + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(file, 0, file.length) } - throw new IllegalStateException("Failed to find shuffle block: " + id) } /** Remove all the blocks / files and metadata related to a particular shuffle. */ @@ -207,14 +204,7 @@ class ShuffleBlockManager(blockManager: BlockManager, private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (sortBasedShuffle) { - // There's a single block ID for each map, plus an index file for it - for (mapId <- state.completedMapTasks) { - val blockId = new ShuffleBlockId(shuffleId, mapId, 0) - blockManager.diskBlockManager.getFile(blockId).delete() - blockManager.diskBlockManager.getFile(blockId.name + ".index").delete() - } - } else if (consolidateShuffleFiles) { + if (consolidateShuffleFiles) { for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { file.delete() } @@ -240,13 +230,13 @@ class ShuffleBlockManager(blockManager: BlockManager, shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } - def stop() { + override def stop() { metadataCleaner.cancel() } } private[spark] -object ShuffleBlockManager { +object FileShuffleBlockManager { /** * A group of shuffle files, one per reducer. * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala new file mode 100644 index 0000000000000..4ab34336d3f01 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io._ +import java.nio.ByteBuffer + +import org.apache.spark.SparkEnv +import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer} +import org.apache.spark.storage._ + +/** + * Create and maintain the shuffle blocks' mapping between logic block and physical file location. + * Data of shuffle blocks from the same map task are stored in a single consolidated data file. + * The offsets of the data blocks in the data file are stored in a separate index file. + * + * We use the name of the shuffle data's shuffleBlockId with reduce ID set to 0 and add ".data" + * as the filename postfix for data file, and ".index" as the filename postfix for index file. + * + */ +private[spark] +class IndexShuffleBlockManager extends ShuffleBlockManager { + + private lazy val blockManager = SparkEnv.get.blockManager + + /** + * Mapping to a single shuffleBlockId with reduce ID 0. + * */ + def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = { + ShuffleBlockId(shuffleId, mapId, 0) + } + + def getDataFile(shuffleId: Int, mapId: Int): File = { + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) + } + + private def getIndexFile(shuffleId: Int, mapId: Int): File = { + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) + } + + /** + * Remove data file and index file that contain the output data from one map. + * */ + def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { + var file = getDataFile(shuffleId, mapId) + if (file.exists()) { + file.delete() + } + + file = getIndexFile(shuffleId, mapId) + if (file.exists()) { + file.delete() + } + } + + /** + * Write an index file with the offsets of each block, plus a final offset at the end for the + * end of the output file. This will be used by getBlockLocation to figure out where each block + * begins and ends. + * */ + def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]) = { + val indexFile = getIndexFile(shuffleId, mapId) + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L + out.writeLong(offset) + + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } finally { + out.close() + } + } + + override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { + Some(getBlockData(blockId).nioByteBuffer()) + } + + override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + // The block is actually going to be a range of a single map output file for this map, so + // find out the consolidated file, then the offset within that from our index + val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + + val in = new DataInputStream(new FileInputStream(indexFile)) + try { + in.skip(blockId.reduceId * 8) + val offset = in.readLong() + val nextOffset = in.readLong() + new FileSegmentManagedBuffer( + getDataFile(blockId.shuffleId, blockId.mapId), + offset, + nextOffset - offset) + } finally { + in.close() + } + } + + override def stop() = {} +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala new file mode 100644 index 0000000000000..63863cc0250a3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.nio.ByteBuffer + +import org.apache.spark.network.ManagedBuffer +import org.apache.spark.storage.ShuffleBlockId + +private[spark] +trait ShuffleBlockManager { + type ShuffleId = Int + + /** + * Get shuffle block data managed by the local ShuffleBlockManager. + * @return Some(ByteBuffer) if block found, otherwise None. + */ + def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] + + def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 9c859b8b4a118..801ae54086053 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -49,8 +49,13 @@ private[spark] trait ShuffleManager { endPartition: Int, context: TaskContext): ShuffleReader[K, C] - /** Remove a shuffle's metadata from the ShuffleManager. */ - def unregisterShuffle(shuffleId: Int) + /** + * Remove a shuffle's metadata from the ShuffleManager. + * @return true if the metadata removed successfully, otherwise false. + */ + def unregisterShuffle(shuffleId: Int): Boolean + + def shuffleBlockManager: ShuffleBlockManager /** Shut down this ShuffleManager. */ def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 12b475658e29d..6cf9305977a3c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -21,10 +21,9 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { @@ -32,8 +31,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer, - shuffleMetrics: ShuffleReadMetrics) + serializer: Serializer) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -74,7 +72,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics) + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + SparkEnv.get.blockTransferService, + blockManager, + blocksByAddress, + serializer, + SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index df98d18fa8193..62e0629b34400 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -25,6 +25,9 @@ import org.apache.spark.shuffle._ * mapper (possibly reusing these across waves of tasks). */ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { + + private val fileShuffleBlockManager = new FileShuffleBlockManager(conf) + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, @@ -49,12 +52,21 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager /** Get a writer for a given partition. Called on executors by map tasks. */ override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { - new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + new HashShuffleWriter( + shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Unit = {} + override def unregisterShuffle(shuffleId: Int): Boolean = { + shuffleBlockManager.removeShuffle(shuffleId) + } + + override def shuffleBlockManager: FileShuffleBlockManager = { + fileShuffleBlockManager + } /** Shut down this ShuffleManager. */ - override def stop(): Unit = {} + override def stop(): Unit = { + shuffleBlockManager.stop() + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 7bed97a63f0f6..88a5f1e5ddf58 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,10 +36,8 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, - readMetrics) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 51e454d9313c9..4b9454d75abb7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -17,14 +17,15 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} -import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext} -import org.apache.spark.storage.{BlockObjectWriter} -import org.apache.spark.serializer.Serializer +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ +import org.apache.spark.storage.BlockObjectWriter private[spark] class HashShuffleWriter[K, V]( + shuffleBlockManager: FileShuffleBlockManager, handle: BaseShuffleHandle[K, V, _], mapId: Int, context: TaskContext) @@ -43,7 +44,6 @@ private[spark] class HashShuffleWriter[K, V]( metrics.shuffleWriteMetrics = Some(writeMetrics) private val blockManager = SparkEnv.get.blockManager - private val shuffleBlockManager = blockManager.shuffleBlockManager private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser, writeMetrics) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 6dcca47ea7c0c..b727438ae7e47 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -17,14 +17,17 @@ package org.apache.spark.shuffle.sort -import java.io.{DataInputStream, FileInputStream} +import java.util.concurrent.ConcurrentHashMap +import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.{TaskContext, ShuffleDependency} import org.apache.spark.shuffle.hash.HashShuffleReader -import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId} -private[spark] class SortShuffleManager extends ShuffleManager { +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { + + private val indexShuffleBlockManager = new IndexShuffleBlockManager() + private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() + /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ @@ -52,29 +55,29 @@ private[spark] class SortShuffleManager extends ShuffleManager { /** Get a writer for a given partition. Called on executors by map tasks. */ override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { - new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] + shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) + new SortShuffleWriter( + shuffleBlockManager, baseShuffleHandle, mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Unit = {} + override def unregisterShuffle(shuffleId: Int): Boolean = { + if (shuffleMapNumber.containsKey(shuffleId)) { + val numMaps = shuffleMapNumber.remove(shuffleId) + (0 until numMaps).map{ mapId => + shuffleBlockManager.removeDataByMap(shuffleId, mapId) + } + } + true + } - /** Shut down this ShuffleManager. */ - override def stop(): Unit = {} + override def shuffleBlockManager: IndexShuffleBlockManager = { + indexShuffleBlockManager + } - /** Get the location of a block in a map output file. Uses the index file we create for it. */ - def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = { - // The block is actually going to be a range of a single map output file for this map, so - // figure out the ID of the consolidated file, then the offset within that from our index - val consolidatedId = blockId.copy(reduceId = 0) - val indexFile = diskManager.getFile(consolidatedId.name + ".index") - val in = new DataInputStream(new FileInputStream(indexFile)) - try { - in.skip(blockId.reduceId * 8) - val offset = in.readLong() - val nextOffset = in.readLong() - new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset) - } finally { - in.close() - } + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + shuffleBlockManager.stop() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 22f656fa371ea..89a78d6982ba0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,34 +17,25 @@ package org.apache.spark.shuffle.sort -import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream} - import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.shuffle.{IndexShuffleBlockManager, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( + shuffleBlockManager: IndexShuffleBlockManager, handle: BaseShuffleHandle[K, V, C], mapId: Int, context: TaskContext) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency - private val numPartitions = dep.partitioner.numPartitions private val blockManager = SparkEnv.get.blockManager - private val ser = Serializer.getSerializer(dep.serializer.orNull) - - private val conf = SparkEnv.get.conf - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 private var sorter: ExternalSorter[K, V, _] = null - private var outputFile: File = null - private var indexFile: File = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -74,17 +65,10 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter.insertAll(records) } - // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later - // serve different ranges of this file using an index file that we create at the end. - val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) - - outputFile = blockManager.diskBlockManager.getFile(blockId) - indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index") - - val partitionLengths = sorter.writePartitionedFile(blockId, context) - - // Register our map output with the ShuffleBlockManager, which handles cleaning it over time - blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions) + val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) + val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) + val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) + shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) mapStatus = new MapStatus(blockManager.blockManagerId, partitionLengths.map(MapOutputTracker.compressSize)) @@ -100,13 +84,8 @@ private[spark] class SortShuffleWriter[K, V, C]( if (success) { return Option(mapStatus) } else { - // The map task failed, so delete our output file if we created one - if (outputFile != null) { - outputFile.delete() - } - if (indexFile != null) { - indexFile.delete() - } + // The map task failed, so delete our output data. + shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId) return None } } finally { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala deleted file mode 100644 index ca60ec78b62ee..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.LinkedBlockingQueue -import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue -import scala.util.{Failure, Success} - -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.executor.ShuffleReadMetrics -import org.apache.spark.network.BufferMessage -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils - -/** - * A block fetcher iterator interface. There are two implementations: - * - * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. - * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. - * - * Eventually we would like the two to converge and use a single NIO-based communication layer, - * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), - * NIO would perform poorly and thus the need for the Netty OIO one. - */ - -private[storage] -trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { - def initialize() -} - - -private[storage] -object BlockFetcherIterator { - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address remote BlockManager to fetch from. - * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. - */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { - val size = blocks.map(_._2).sum - } - - /** - * Result of a fetch from a remote block. A failure is represented as size == -1. - * @param blockId block id - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. - */ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - - // TODO: Refactor this whole thing to make code more reusable. - class BasicBlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics) - extends BlockFetcherIterator { - - import blockManager._ - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - - // Total number blocks fetched (local + remote). Also number of FetchResults expected - protected var _numBlocksToFetch = 0 - - protected var startTime = System.currentTimeMillis - - // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - - // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[BlockId]() - - // A queue to hold our results. - protected val results = new LinkedBlockingQueue[FetchResult] - - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - protected val fetchRequests = new Queue[FetchRequest] - - // Current bytes in flight from our requests - protected var bytesInFlight = 0L - - protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onComplete { - case Success(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can - // be incrementing bytes read at the same time (SPARK-2625). - readMetrics.remoteBytesRead += networkSize - readMetrics.remoteBlocksFetched += 1 - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - case Failure(exception) => { - logError("Could not get block(s) from " + cmId, exception) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - } - } - - protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) - - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - var totalBlocks = 0 - for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size - if (address == blockManagerId) { - // Filter out zero-sized blocks - localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) - _numBlocksToFetch += localBlocksToFetch.size - } else { - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(BlockId, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { - curBlocks += ((blockId, size)) - remoteBlocksToFetch += blockId - _numBlocksToFetch += 1 - curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= targetRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") - curRequestSize = 0 - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - totalBlocks + " blocks") - remoteRequests - } - - protected def getLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocksToFetch) { - try { - // getLocalFromDisk never return None but throws BlockException - val iter = getLocalFromDisk(id, serializer).get - // Pass 0 as size since it's not in flight - readMetrics.localBlocksFetched += 1 - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } catch { - case e: Exception => { - logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) - return - } - } - } - } - - override def initialize() { - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - - val numFetches = remoteRequests.size - fetchRequests.size - logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue - // as they arrive. - @volatile protected var resultsGotten = 0 - - override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - - override def next(): (BlockId, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (! result.failed) bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - } - // End of BasicBlockFetcherIterator - - class NettyBlockFetcherIterator( - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { - - override protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - - // This could throw a TimeoutException. In that case we will just retry the task. - val client = blockManager.nettyBlockClientFactory.createClient( - cmId.host, req.address.nettyPort) - val blocks = req.blocks.map(_._1.toString) - - client.fetchBlocks( - blocks, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = { - logError(s"Could not get block(s) from $cmId with error: $errorMsg") - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - - override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { - // Increment the reference count so the buffer won't be recycled. - // TODO: This could result in memory leaks when the task is stopped due to exception - // before the iterator is exhausted. - data.retain() - val buf = data.byteBuffer() - val blockSize = buf.remaining() - val bid = BlockId(blockId) - - // TODO: remove code duplication between here and BlockManager.dataDeserialization. - results.put(new FetchResult(bid, sizeMap(bid), () => { - def createIterator: Iterator[Any] = { - val stream = blockManager.wrapForCompression(bid, data.inputStream()) - serializer.newInstance().deserializeStream(stream).asIterator - } - new LazyInitIterator(createIterator) { - // Release the buffer when we are done traversing it. - override def close(): Unit = data.release() - } - })) - - readMetrics.synchronized { - readMetrics.remoteBytesRead += blockSize - readMetrics.remoteBlocksFetched += 1 - } - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - ) - } - } - // End of NettyBlockFetcherIterator -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index c1756ac905417..a83a3f468ae5f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -58,6 +58,11 @@ case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends Blo def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +@DeveloperApi +case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { + def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" +} + @DeveloperApi case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" @@ -92,6 +97,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r @@ -104,6 +110,8 @@ object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case SHUFFLE_DATA(shuffleId, mapId, reduceId) => + ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId, field) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 12a92d44f4c36..d1bee3d2c033c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,6 +20,8 @@ package org.apache.spark.storage import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} +import scala.concurrent.ExecutionContext.Implicits.global + import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ @@ -32,8 +34,6 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ -import org.apache.spark.network.netty.client.BlockFetchingClientFactory -import org.apache.spark.network.netty.server.BlockServer import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ @@ -60,18 +60,14 @@ private[spark] class BlockManager( defaultSerializer: Serializer, maxMemory: Long, val conf: SparkConf, - securityManager: SecurityManager, mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager) - extends BlockDataProvider with Logging { + shuffleManager: ShuffleManager, + blockTransferService: BlockTransferService) + extends BlockDataManager with Logging { - private val port = conf.getInt("spark.blockManager.port", 0) - val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) - val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) - val connectionManager = - new ConnectionManager(port, conf, securityManager, "Connection manager for block manager") + blockTransferService.init(this) - implicit val futureExecContext = connectionManager.futureExecContext + val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -85,36 +81,13 @@ private[spark] class BlockManager( val tachyonStorePath = s"$storeDir/$appFolderName/${this.executorId}" val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998") val tachyonBlockManager = - new TachyonBlockManager(shuffleBlockManager, tachyonStorePath, tachyonMaster) + new TachyonBlockManager(this, tachyonStorePath, tachyonMaster) tachyonInitialized = true new TachyonStore(this, tachyonBlockManager) } - private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) - - // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = { - if (useNetty) new BlockFetchingClientFactory(conf) else null - } - - private val nettyBlockServer: BlockServer = { - if (useNetty) { - val server = new BlockServer(conf, this) - logInfo(s"Created NettyBlockServer binding to port: ${server.port}") - server - } else { - null - } - } - - private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0 - val blockManagerId = BlockManagerId( - executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) - - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory - // for receiving shuffle outputs) - val maxBytesInFlight = conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 + executorId, blockTransferService.hostName, blockTransferService.port) // Whether to compress broadcast variables that are stored private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) @@ -157,11 +130,11 @@ private[spark] class BlockManager( master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, - securityManager: SecurityManager, mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager) = { + shuffleManager: ShuffleManager, + blockTransferService: BlockTransferService) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker, shuffleManager) + conf, mapOutputTracker, shuffleManager, blockTransferService) } /** @@ -170,7 +143,6 @@ private[spark] class BlockManager( */ private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) - BlockManagerWorker.startBlockManagerWorker(this) } /** @@ -233,20 +205,33 @@ private[spark] class BlockManager( } } - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + /** + * Interface to get local block data. + * + * @return Some(buffer) if the block exists locally, and None if it doesn't. + */ + override def getBlockData(blockId: String): Option[ManagedBuffer] = { val bid = BlockId(blockId) if (bid.isShuffle) { - Left(diskBlockManager.getBlockLocation(bid)) + Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])) } else { val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { - Right(blockBytesOpt.get) + val buffer = blockBytesOpt.get + Some(new NioByteBufferManagedBuffer(buffer)) } else { - throw new BlockNotFoundException(blockId) + None } } } + /** + * Put the block locally, using the given storage level. + */ + override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { + putBytes(BlockId(blockId), data.nioByteBuffer(), level) + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. @@ -354,10 +339,10 @@ private[spark] class BlockManager( * shuffle blocks. It is safe to do so without a lock on block info since disk store * never deletes (recent) items. */ - def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer).orElse { - throw new BlockException(blockId, s"Block $blockId not found on disk, though it should be") - } + def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { + val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + val is = wrapForCompression(blockId, buf.inputStream()) + Some(serializer.newInstance().deserializeStream(is).asIterator) } /** @@ -376,7 +361,8 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - diskStore.getBytes(blockId) match { + val shuffleBlockManager = shuffleManager.shuffleBlockManager + shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match { case Some(bytes) => Some(bytes) case None => @@ -527,8 +513,9 @@ private[spark] class BlockManager( val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) + val data = blockTransferService.fetchBlockSync( + loc.host, loc.port, blockId.toString).nioByteBuffer() + if (data != null) { if (asBlockResult) { return Some(new BlockResult( @@ -562,28 +549,6 @@ private[spark] class BlockManager( None } - /** - * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns - * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined - * fashion as they're received. Expects a size in bytes to be provided for each block fetched, - * so that we can control the maxMegabytesInFlight for the fetch. - */ - def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics): BlockFetcherIterator = { - val iter = - if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer, - readMetrics) - } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer, - readMetrics) - } - iter.initialize() - iter - } - def putIterator( blockId: BlockId, values: Iterator[Any], @@ -836,12 +801,15 @@ private[spark] class BlockManager( data.rewind() logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " + s"To node: $peer") - val putBlock = PutBlock(blockId, data, tLevel) - val cmId = new ConnectionManagerId(peer.host, peer.port) - val syncPutBlockSuccess = BlockManagerWorker.syncPutBlock(putBlock, cmId) - if (!syncPutBlockSuccess) { - logError(s"Failed to call syncPutBlock to $peer") + + try { + blockTransferService.uploadBlockSync( + peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + } catch { + case e: Exception => + logError(s"Failed to replicate block to $peer", e) } + logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes." .format(blockId, (System.nanoTime - start) / 1e6, data.limit())) } @@ -1066,40 +1034,13 @@ private[spark] class BlockManager( bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - - def getIterator: Iterator[Any] = { - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator - } - - if (blockId.isShuffle) { - /* Reducer may need to read many local shuffle blocks and will wrap them into Iterators - * at the beginning. The wrapping will cost some memory (compression instance - * initialization, etc.). Reducer reads shuffle blocks one by one so we could do the - * wrapping lazily to save memory. */ - class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] { - lazy val proxy = f - override def hasNext: Boolean = proxy.hasNext - override def next(): Any = proxy.next() - } - new LazyProxyIterator(getIterator) - } else { - getIterator - } + val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) + serializer.newInstance().deserializeStream(stream).asIterator } def stop(): Unit = { - connectionManager.stop() - shuffleBlockManager.stop() + blockTransferService.stop() diskBlockManager.stop() - - if (nettyBlockClientFactory != null) { - nettyBlockClientFactory.stop() - } - if (nettyBlockServer != null) { - nettyBlockServer.stop() - } - actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b1585bd8199d1..d4487fce49ab6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -36,11 +36,10 @@ import org.apache.spark.util.Utils class BlockManagerId private ( private var executorId_ : String, private var host_ : String, - private var port_ : Int, - private var nettyPort_ : Int - ) extends Externalizable { + private var port_ : Int) + extends Externalizable { - private def this() = this(null, null, 0, 0) // For deserialization only + private def this() = this(null, null, 0) // For deserialization only def executorId: String = executorId_ @@ -60,32 +59,28 @@ class BlockManagerId private ( def port: Int = port_ - def nettyPort: Int = nettyPort_ - override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) - out.writeInt(nettyPort_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() - nettyPort_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort) + override def toString = s"BlockManagerId($executorId, $host, $port)" - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort + executorId == id.executorId && port == id.port && host == id.host case _ => false } @@ -100,11 +95,10 @@ private[spark] object BlockManagerId { * @param execId ID of the executor. * @param host Host name of the block manager. * @param port Port of the block manager. - * @param nettyPort Optional port for the Netty-based shuffle sender. * @return A new [[org.apache.spark.storage.BlockManagerId]]. */ - def apply(execId: String, host: String, port: Int, nettyPort: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) + def apply(execId: String, host: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 669307765d1fa..2e262594b3538 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -27,7 +27,11 @@ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] -class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging { +class BlockManagerMaster( + var driverActor: ActorRef, + conf: SparkConf, + isDriver: Boolean) + extends Logging { private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) @@ -101,7 +105,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log def removeRdd(rddId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { - case e: Throwable => logError("Failed to remove RDD " + rddId, e) + case e: Exception => + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") } if (blocking) { Await.result(future, timeout) @@ -112,7 +117,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log def removeShuffle(shuffleId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { - case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e) + case e: Exception => + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") } if (blocking) { Await.result(future, timeout) @@ -124,9 +130,9 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log val future = askDriverWithReply[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { - case e: Throwable => - logError("Failed to remove broadcast " + broadcastId + - " with removeFromMaster = " + removeFromMaster, e) + case e: Exception => + logWarning(s"Failed to remove broadcast $broadcastId" + + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}") } if (blocking) { Await.result(future, timeout) @@ -194,7 +200,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log /** Stop the driver actor, called only on the Spark driver node */ def stop() { - if (driverActor != null) { + if (driverActor != null && isDriver) { tell(StopBlockManagerMaster) driverActor = null logInfo("BlockManagerMaster stopped") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 3ab07703b6f85..1a6c7cb24f9ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -203,7 +203,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockLocations.remove(blockId) } } - listenerBus.post(SparkListenerBlockManagerRemoved(blockManagerId)) + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) } private def expireDeadHosts() { @@ -325,6 +325,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => @@ -340,9 +341,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus id.hostPort, Utils.bytesToString(maxMemSize))) blockManagerInfo(id) = - new BlockManagerInfo(id, System.currentTimeMillis(), maxMemSize, slaveActor) + new BlockManagerInfo(id, time, maxMemSize, slaveActor) } - listenerBus.post(SparkListenerBlockManagerAdded(id, maxMemSize)) + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } private def updateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index c194e0fed3367..14ae2f38c5670 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -21,7 +21,7 @@ import scala.concurrent.Future import akka.actor.{ActorRef, Actor} -import org.apache.spark.{Logging, MapOutputTracker} +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.ActorLogReceive @@ -55,7 +55,7 @@ class BlockManagerSlaveActor( if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } - blockManager.shuffleBlockManager.removeShuffle(shuffleId) + SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) } case RemoveBroadcast(broadcastId, tellMaster) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala deleted file mode 100644 index bf002a42d5dc5..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -import org.apache.spark.Logging -import org.apache.spark.network._ -import org.apache.spark.util.Utils - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.{Try, Failure, Success} - -/** - * A network interface for BlockManager. Each slave should have one - * BlockManagerWorker. - * - * TODO: Use event model. - */ -private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { - - blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) - - def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => { - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => { - logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) - } - } - } - case otherMessage: Any => { - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) - } - } - } - - def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + pB + "]") - putBlock(pB.id, pB.data, pB.level) - None - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - logDebug("Received [" + gB + "]") - val buffer = getBlock(gB.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) - } - case _ => None - } - } - - private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) - blockManager.putBytes(id, bytes, level) - logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(id: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + id + " started from " + startTimeMs) - val buffer = blockManager.getLocalBytes(id) match { - case Some(bytes) => bytes - case None => null - } - logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer - } -} - -private[spark] object BlockManagerWorker extends Logging { - private var blockManagerWorker: BlockManagerWorker = null - - def startBlockManagerWorker(manager: BlockManager) { - blockManagerWorker = new BlockManagerWorker(manager) - } - - def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromPutBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = Try(Await.result(connectionManager.sendMessageReliably( - toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) - resultMessage.isSuccess - } - - def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromGetBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = Try(Await.result(connectionManager.sendMessageReliably( - toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) - responseMessage match { - case Success(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - logDebug("Response message received " + bufferMessage) - BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { - logDebug("Found " + blockMessage) - return blockMessage.getData - }) - } - case Failure(exception) => logDebug("No response message received") - } - null - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index adda971fd7b47..9c469370ffe1f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -65,8 +65,6 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { /** * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. - * The given write metrics will be updated incrementally, but will not necessarily be current until - * commitAndClose is called. */ private[spark] class DiskBlockObjectWriter( blockId: BlockId, @@ -75,6 +73,8 @@ private[spark] class DiskBlockObjectWriter( bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, + // These write metrics concurrently shared with other active BlockObjectWriter's who + // are themselves performing writes. All updates must be relative. writeMetrics: ShuffleWriteMetrics) extends BlockObjectWriter(blockId) with Logging @@ -94,14 +94,30 @@ private[spark] class DiskBlockObjectWriter( private var fos: FileOutputStream = null private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null + private var initialized = false + + /** + * Cursors used to represent positions in the file. + * + * xxxxxxxx|--------|--- | + * ^ ^ ^ + * | | finalPosition + * | reportedPosition + * initialPosition + * + * initialPosition: Offset in the file where we start writing. Immutable. + * reportedPosition: Position at the time of the last update to the write metrics. + * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed. + * -----: Current writes to the underlying file. + * xxxxx: Existing contents of the file. + */ private val initialPosition = file.length() private var finalPosition: Long = -1 - private var initialized = false + private var reportedPosition = initialPosition /** Calling channel.position() to update the write metrics can be a little bit expensive, so we * only call it every N writes */ private var writesSinceMetricsUpdate = 0 - private var lastPosition = initialPosition override def open(): BlockObjectWriter = { fos = new FileOutputStream(file, true) @@ -140,17 +156,18 @@ private[spark] class DiskBlockObjectWriter( // serializer stream and the lower level stream. objOut.flush() bs.flush() - updateBytesWritten() close() } finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.shuffleBytesWritten += (finalPosition - reportedPosition) } // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { - writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition) + writeMetrics.shuffleBytesWritten -= (reportedPosition - initialPosition) if (initialized) { objOut.flush() @@ -189,10 +206,14 @@ private[spark] class DiskBlockObjectWriter( new FileSegment(file, initialPosition, finalPosition - initialPosition) } + /** + * Report the number of bytes written in this writer's shuffle write metrics. + * Note that this is only valid before the underlying streams are closed. + */ private def updateBytesWritten() { val pos = channel.position() - writeMetrics.shuffleBytesWritten += (pos - lastPosition) - lastPosition = pos + writeMetrics.shuffleBytesWritten += (pos - reportedPosition) + reportedPosition = pos } private def callWithTiming(f: => Unit) = { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index ec022ce9c048a..a715594f198c2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -21,11 +21,9 @@ import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} -import org.apache.spark.{SparkConf, SparkEnv, Logging} +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.network.netty.PathResolver import org.apache.spark.util.Utils -import org.apache.spark.shuffle.sort.SortShuffleManager /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -36,13 +34,11 @@ import org.apache.spark.shuffle.sort.SortShuffleManager * Block files are hashed among the directories listed in spark.local.dir (or in * SPARK_LOCAL_DIRS, if it's set). */ -private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, conf: SparkConf) - extends PathResolver with Logging { +private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf) + extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - - private val subDirsPerLocalDir = - shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64) + private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid @@ -56,26 +52,6 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, addShutdownHook() - /** - * Returns the physical file segment in which the given BlockId is located. If the BlockId has - * been mapped to a specific FileSegment by the shuffle layer, that will be returned. - * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId. - */ - def getBlockLocation(blockId: BlockId): FileSegment = { - val env = SparkEnv.get // NOTE: can be null in unit tests - if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) { - // For sort-based shuffle, let it figure out its blocks - val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager] - sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this) - } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) { - // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this - shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) - } else { - val file = getFile(blockId.name) - new FileSegment(file, 0, file.length()) - } - } - def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) @@ -105,7 +81,7 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, /** Check if disk block manager has a block. */ def containsBlock(blockId: BlockId): Boolean = { - getBlockLocation(blockId).file.exists() + getFile(blockId.name).exists() } /** List all the files currently stored on disk by the disk manager. */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index c83261dd91b36..e9304f6bb45d0 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, RandomAccessFile} +import java.io.{File, FileOutputStream, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode @@ -34,7 +34,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L) override def getSize(blockId: BlockId): Long = { - diskManager.getBlockLocation(blockId).length + diskManager.getFile(blockId.name).length } override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = { @@ -89,25 +89,33 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val segment = diskManager.getBlockLocation(blockId) - val channel = new RandomAccessFile(segment.file, "r").getChannel + private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { + val channel = new RandomAccessFile(file, "r").getChannel try { // For small files, directly read rather than memory map - if (segment.length < minMemoryMapBytes) { - val buf = ByteBuffer.allocate(segment.length.toInt) - channel.read(buf, segment.offset) + if (length < minMemoryMapBytes) { + val buf = ByteBuffer.allocate(length.toInt) + channel.read(buf, offset) buf.flip() Some(buf) } else { - Some(channel.map(MapMode.READ_ONLY, segment.offset, segment.length)) + Some(channel.map(MapMode.READ_ONLY, offset, length)) } } finally { channel.close() } } + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + val file = diskManager.getFile(blockId.name) + getBytes(file, 0, file.length) + } + + def getBytes(segment: FileSegment): Option[ByteBuffer] = { + getBytes(segment.file, segment.offset, segment.length) + } + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } @@ -117,24 +125,25 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc * shuffle short-circuit code. */ def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { + // TODO: Should bypass getBytes and use a stream based implementation, so that + // we won't use a lot of memory during e.g. external sort merge. getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) } override def remove(blockId: BlockId): Boolean = { - val fileSegment = diskManager.getBlockLocation(blockId) - val file = fileSegment.file - if (file.exists() && file.length() == fileSegment.length) { + val file = diskManager.getFile(blockId.name) + // If consolidation mode is used With HashShuffleMananger, the physical filename for the block + // is different from blockId.name. So the file returns here will not be exist, thus we avoid to + // delete the whole consolidated file by mistake. + if (file.exists()) { file.delete() } else { - if (fileSegment.length < file.length()) { - logWarning(s"Could not delete block associated with only a part of a file: $blockId") - } false } } override def contains(blockId: BlockId): Boolean = { - val file = diskManager.getBlockLocation(blockId).file + val file = diskManager.getFile(blockId.name) file.exists() } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala new file mode 100644 index 0000000000000..c8e708aa6b1bc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet +import scala.collection.mutable.Queue + +import org.apache.spark.{TaskContext, Logging, SparkException} +import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils + + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context [[TaskContext]], used for metrics update + * @param blockTransferService [[BlockTransferService]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks + * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. + * For each block we also require the size (in bytes as a long field) in + * order to throttle the memory usage. + * @param serializer serializer used to deserialize the data. + * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + */ +private[spark] +final class ShuffleBlockFetcherIterator( + context: TaskContext, + blockTransferService: BlockTransferService, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + serializer: Serializer, + maxBytesInFlight: Long) + extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + + import ShuffleBlockFetcherIterator._ + + /** + * Total number of blocks to fetch. This can be smaller than the total number of blocks + * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * + * This should equal localBlocks.size + remoteBlocks.size. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks proccessed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTime = System.currentTimeMillis + + /** Local blocks to fetch, excluding zero-sized blocks. */ + private[this] val localBlocks = new ArrayBuffer[BlockId]() + + /** Remote blocks to fetch, excluding zero-sized blocks. */ + private[this] val remoteBlocks = new HashSet[BlockId]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + private[this] val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + private[this] var bytesInFlight = 0L + + private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + + initialize() + + private[this] def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + + // so we can look up the size of each blockID + val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap + val blockIds = req.blocks.map(_._1.toString) + + blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, + new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), + () => serializer.newInstance().deserializeStream( + blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator + )) + shuffleMetrics.remoteBytesRead += data.size + shuffleMetrics.remoteBlocksFetched += 1 + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } + + override def onBlockFetchFailure(e: Throwable): Unit = { + logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + // Note that there is a chance that some blocks have been fetched successfully, but we + // still add them to the failed queue. This is fine because when the caller see a + // FetchFailedException, it is going to fail the entire task anyway. + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } + } + ) + } + + private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + + // Tracks total number of blocks (including zero sized blocks) + var totalBlocks = 0 + for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size + if (address == blockManager.blockManagerId) { + // Filter out zero-sized blocks + localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + numBlocksToFetch += localBlocks.size + } else { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(BlockId, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + // Skip empty blocks + if (size > 0) { + curBlocks += ((blockId, size)) + remoteBlocks += blockId + numBlocksToFetch += 1 + curRequestSize += size + } else if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } + if (curRequestSize >= targetRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curBlocks = new ArrayBuffer[(BlockId, Long)] + logDebug(s"Creating fetch request of $curRequestSize at $address") + curRequestSize = 0 + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + remoteRequests + } + + private[this] def fetchLocalBlocks() { + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlocks) { + try { + shuffleMetrics.localBlocksFetched += 1 + results.put(new FetchResult( + id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get)) + logDebug("Got local block " + id) + } catch { + case e: Exception => + logError(s"Error occurred while fetching local blocks", e) + results.put(new FetchResult(id, -1, null)) + return + } + } + } + + private[this] def initialize(): Unit = { + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (fetchRequests.nonEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numFetches = remoteRequests.size - fetchRequests.size + logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + fetchLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + override def next(): (BlockId, Option[Iterator[Any]]) = { + numBlocksProcessed += 1 + val startFetchWait = System.currentTimeMillis() + val result = results.take() + val stopFetchWait = System.currentTimeMillis() + shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) + if (!result.failed) { + bytesInFlight -= result.size + } + // Send fetch requests up to maxBytesInFlight + while (fetchRequests.nonEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } +} + + +private[storage] +object ShuffleBlockFetcherIterator { + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + val size = blocks.map(_._2).sum + } + + /** + * Result of a fetch from a remote block. A failure is represented as size == -1. + * @param blockId block id + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param deserialize closure to return the result in the form of an Iterator. + */ + class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index a6cbe3aa440ff..6908a59a79e60 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.Utils * @param rootDirs The directories to use for storing block files. Data will be hashed among these. */ private[spark] class TachyonBlockManager( - shuffleManager: ShuffleBlockManager, + blockManager: BlockManager, rootDirs: String, val master: String) extends Logging { @@ -49,7 +49,7 @@ private[spark] class TachyonBlockManager( private val MAX_DIR_CREATION_ATTEMPTS = 10 private val subDirsPerTachyonDir = - shuffleManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt + blockManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; // then, inside this directory, create multiple subdirectories that we will hash files into, diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala deleted file mode 100644 index aa83ea90ee9ee..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.ArrayBlockingQueue - -import akka.actor._ -import org.apache.spark.shuffle.hash.HashShuffleManager -import util.Random - -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} -import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.KryoSerializer - -/** - * This class tests the BlockManager and MemoryStore for thread safety and - * deadlocks. It spawns a number of producer and consumer threads. Producer - * threads continuously pushes blocks into the BlockManager and consumer - * threads continuously retrieves the blocks form the BlockManager and tests - * whether the block is correct or not. - */ -private[spark] object ThreadingTest { - - val numProducers = 5 - val numBlocksPerProducer = 20000 - - private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { - val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100) - - override def run() { - for (i <- 1 to numBlocksPerProducer) { - val blockId = TestBlockId("b-" + id + "-" + i) - val blockSize = Random.nextInt(1000) - val block = (1 to blockSize).map(_ => Random.nextInt()) - val level = randomLevel() - val startTime = System.currentTimeMillis() - manager.putIterator(blockId, block.iterator, level, tellMaster = true) - println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") - queue.add((blockId, block)) - } - println("Producer thread " + id + " terminated") - } - - def randomLevel(): StorageLevel = { - math.abs(Random.nextInt()) % 4 match { - case 0 => StorageLevel.MEMORY_ONLY - case 1 => StorageLevel.MEMORY_ONLY_SER - case 2 => StorageLevel.MEMORY_AND_DISK - case 3 => StorageLevel.MEMORY_AND_DISK_SER - } - } - } - - private[spark] class ConsumerThread( - manager: BlockManager, - queue: ArrayBlockingQueue[(BlockId, Seq[Int])] - ) extends Thread { - var numBlockConsumed = 0 - - override def run() { - println("Consumer thread started") - while(numBlockConsumed < numBlocksPerProducer) { - val (blockId, block) = queue.take() - val startTime = System.currentTimeMillis() - manager.get(blockId) match { - case Some(retrievedBlock) => - assert(retrievedBlock.data.toList.asInstanceOf[List[Int]] == block.toList, - "Block " + blockId + " did not match") - println("Got block " + blockId + " in " + - (System.currentTimeMillis - startTime) + " ms") - case None => - assert(false, "Block " + blockId + " could not be retrieved") - } - numBlockConsumed += 1 - } - println("Consumer thread terminated") - } - } - - def main(args: Array[String]) { - System.setProperty("spark.kryoserializer.buffer.mb", "1") - val actorSystem = ActorSystem("test") - val conf = new SparkConf() - val serializer = new KryoSerializer(conf) - val blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf) - val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf)) - val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) - val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) - producers.foreach(_.start) - consumers.foreach(_.start) - producers.foreach(_.join) - consumers.foreach(_.join) - blockManager.stop() - blockManagerMaster.stop() - actorSystem.shutdown() - actorSystem.awaitTermination() - println("Everything stopped.") - println( - "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index bee6dad3387e5..f0006b42aee4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -232,7 +232,7 @@ private[spark] object UIUtils extends Logging { def listingTable[T]( headers: Seq[String], generateDataRow: T => Seq[Node], - data: Seq[T], + data: Iterable[T], fixedWidth: Boolean = false): Seq[Node] = { var listingTableClass = TABLE_CLASS diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 02df4e8fe61af..b0e3bb3b552fd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -21,7 +21,6 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.StorageLevel import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 0cc51c873727d..2987dc04494a5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -24,8 +24,8 @@ import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils -/** Page showing executor summary */ -private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { +/** Stage summary grouped by executors. */ +private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -65,7 +65,7 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { executorIdToAddress.put(executorId, address) } - listener.stageIdToData.get(stageId) match { + listener.stageIdToData.get((stageId, stageAttemptId)) match { case Some(stageData: StageUIData) => stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 74cd637d88155..eaeb861f59e5a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer, Map} +import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -43,12 +43,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // How many stages to remember val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) - val activeStages = HashMap[Int, StageInfo]() + // Map from stageId to StageInfo + val activeStages = new HashMap[Int, StageInfo] + + // Map from (stageId, attemptId) to StageUIData + val stageIdToData = new HashMap[(Int, Int), StageUIData] + val completedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() - val stageIdToData = new HashMap[Int, StageUIData] - + // Map from pool name to a hash map (map from stage id to StageInfo). val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() val executorIdToBlockManagerId = HashMap[String, BlockManagerId]() @@ -59,9 +63,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stageInfo - val stageId = stage.stageId - val stageData = stageIdToData.getOrElseUpdate(stageId, { - logWarning("Stage completed for unknown stage " + stageId) + val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { + logWarning("Stage completed for unknown stage " + stage.stageId) new StageUIData }) @@ -69,8 +72,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.accumulables(id) = info } - poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) - activeStages.remove(stageId) + poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap => + hashMap.remove(stage.stageId) + } + activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage trimIfNecessary(completedStages) @@ -84,7 +89,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > retainedStages) { val toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => stageIdToData.remove(s.stageId) } + stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) } stages.trimStart(toRemove) } } @@ -98,21 +103,21 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) - val stageData = stageIdToData.getOrElseUpdate(stage.stageId, new StageUIData) + val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData) stageData.schedulingPool = poolName stageData.description = Option(stageSubmitted.properties).flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]()) + val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) stages(stage.stageId) = stage } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { - val stageData = stageIdToData.getOrElseUpdate(taskStart.stageId, { + val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { logWarning("Task start for unknown stage " + taskStart.stageId) new StageUIData }) @@ -128,8 +133,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo - if (info != null) { - val stageData = stageIdToData.getOrElseUpdate(taskEnd.stageId, { + // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task + // compeletion event is for. Let's just drop it here. This means we might have some speculation + // tasks on the web ui that's never marked as complete. + if (info != null && taskEnd.stageAttemptId != -1) { + val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { logWarning("Task end for unknown stage " + taskEnd.stageId) new StageUIData }) @@ -222,8 +230,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { - for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) { - val stageData = stageIdToData.getOrElseUpdate(sid, { + for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), { logWarning("Metrics update for task in unknown stage " + sid) new StageUIData }) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index d4eb02722ad12..db01be596e073 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -34,7 +34,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val stageId = request.getParameter("id").toInt - val stageDataOption = listener.stageIdToData.get(stageId) + val stageAttemptId = request.getParameter("attempt").toInt + val stageDataOption = listener.stageIdToData.get((stageId, stageAttemptId)) if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) { val content = @@ -42,14 +43,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {

Summary Metrics

No tasks have started yet

Tasks

No tasks have started yet - return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent) + return UIUtils.headerSparkPage( + s"Details for Stage $stageId (Attempt $stageAttemptId)", content, parent) } val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) - val accumulables = listener.stageIdToData(stageId).accumulables + val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables val hasInput = stageData.inputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 @@ -211,7 +213,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def quantileRow(data: Seq[Node]): Seq[Node] = {data} Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } - val executorTable = new ExecutorTable(stageId, parent) + + val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = if (accumulables.size > 0) {

Accumulators

++ accumulableTable } else Seq() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 16ad0df45aa0d..2e67310594784 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -97,8 +97,8 @@ private[ui] class StageTableBase( } // scalastyle:on - val nameLinkUri ="%s/stages/stage?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId, s.attemptId) val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -121,7 +121,7 @@ private[ui] class StageTableBase( } val stageDesc = for { - stageData <- listener.stageIdToData.get(s.stageId) + stageData <- listener.stageIdToData.get((s.stageId, s.attemptId)) desc <- stageData.description } yield {
{desc}
@@ -131,7 +131,7 @@ private[ui] class StageTableBase( } protected def stageRow(s: StageInfo): Seq[Node] = { - val stageDataOption = listener.stageIdToData.get(s.stageId) + val stageDataOption = listener.stageIdToData.get((s.stageId, s.attemptId)) if (stageDataOption.isEmpty) { return {s.stageId}No data available for this stage } @@ -154,7 +154,11 @@ private[ui] class StageTableBase( val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" - {s.stageId} ++ + {if (s.attemptId > 0) { + {s.stageId} (retry {s.attemptId}) + } else { + {s.stageId} + }} ++ {if (isFairScheduler) { info.numCachedPartitions > 0 } + // Remove all partitions that are no longer cached in current completed stage + val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet + _rddInfoMap.retain { case (id, info) => + !completedRddIds.contains(id) || info.numCachedPartitions > 0 + } } override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index d6afb73b74242..e2d32c859bbda 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -27,7 +27,7 @@ import akka.pattern.ask import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} /** * Various utility classes for working with Akka. @@ -192,10 +192,11 @@ private[spark] object AkkaUtils extends Logging { } def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { + val driverActorSystemName = SparkEnv.driverActorSystemName val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") - val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" + val url = s"akka.tcp://$driverActorSystemName@$driverHost:$driverPort/user/$name" val timeout = AkkaUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 2e8fbf5a91ee7..6d1fc05a15d2c 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -41,18 +41,40 @@ import org.apache.spark.io.CompressionCodec private[spark] class FileLogger( logDir: String, sparkConf: SparkConf, - hadoopConf: Configuration = SparkHadoopUtil.get.newConfiguration(), + hadoopConf: Configuration, outputBufferSize: Int = 8 * 1024, // 8 KB compress: Boolean = false, overwrite: Boolean = true, dirPermissions: Option[FsPermission] = None) extends Logging { + def this( + logDir: String, + sparkConf: SparkConf, + compress: Boolean = false, + overwrite: Boolean = true) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, + overwrite = overwrite) + } + private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } - private val fileSystem = Utils.getHadoopFileSystem(logDir) + /** + * To avoid effects of FileSystem#close or FileSystem.closeAll called from other modules, + * create unique FileSystem instance only for FileLogger + */ + private val fileSystem = { + val conf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val logUri = new URI(logDir) + val scheme = logUri.getScheme + if (scheme == "hdfs") { + conf.setBoolean("fs.hdfs.impl.disable.cache", true) + } + FileSystem.get(logUri, conf) + } + var fileIndex = 0 // Only used if compression is enabled diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 1e18ec688c40d..b0754e3ce10db 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -96,6 +96,7 @@ private[spark] object JsonProtocol { val taskInfo = taskStart.taskInfo ("Event" -> Utils.getFormattedClassName(taskStart)) ~ ("Stage ID" -> taskStart.stageId) ~ + ("Stage Attempt ID" -> taskStart.stageAttemptId) ~ ("Task Info" -> taskInfoToJson(taskInfo)) } @@ -112,6 +113,7 @@ private[spark] object JsonProtocol { val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing ("Event" -> Utils.getFormattedClassName(taskEnd)) ~ ("Stage ID" -> taskEnd.stageId) ~ + ("Stage Attempt ID" -> taskEnd.stageAttemptId) ~ ("Task Type" -> taskEnd.taskType) ~ ("Task End Reason" -> taskEndReason) ~ ("Task Info" -> taskInfoToJson(taskInfo)) ~ @@ -150,13 +152,15 @@ private[spark] object JsonProtocol { val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId) ("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~ ("Block Manager ID" -> blockManagerId) ~ - ("Maximum Memory" -> blockManagerAdded.maxMem) + ("Maximum Memory" -> blockManagerAdded.maxMem) ~ + ("Timestamp" -> blockManagerAdded.time) } def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId) ("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~ - ("Block Manager ID" -> blockManagerId) + ("Block Manager ID" -> blockManagerId) ~ + ("Timestamp" -> blockManagerRemoved.time) } def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = { @@ -167,6 +171,7 @@ private[spark] object JsonProtocol { def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = { ("Event" -> Utils.getFormattedClassName(applicationStart)) ~ ("App Name" -> applicationStart.appName) ~ + ("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~ ("Timestamp" -> applicationStart.time) ~ ("User" -> applicationStart.sparkUser) } @@ -187,6 +192,7 @@ private[spark] object JsonProtocol { val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ + ("Stage Attempt ID" -> stageInfo.attemptId) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ @@ -292,8 +298,7 @@ private[spark] object JsonProtocol { def blockManagerIdToJson(blockManagerId: BlockManagerId): JValue = { ("Executor ID" -> blockManagerId.executorId) ~ ("Host" -> blockManagerId.host) ~ - ("Port" -> blockManagerId.port) ~ - ("Netty Port" -> blockManagerId.nettyPort) + ("Port" -> blockManagerId.port) } def jobResultToJson(jobResult: JobResult): JValue = { @@ -419,8 +424,9 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") - SparkListenerTaskStart(stageId, taskInfo) + SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } def taskGettingResultFromJson(json: JValue): SparkListenerTaskGettingResult = { @@ -430,11 +436,12 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") val taskMetrics = taskMetricsFromJson(json \ "Task Metrics") - SparkListenerTaskEnd(stageId, taskType, taskEndReason, taskInfo, taskMetrics) + SparkListenerTaskEnd(stageId, stageAttemptId, taskType, taskEndReason, taskInfo, taskMetrics) } def jobStartFromJson(json: JValue): SparkListenerJobStart = { @@ -462,12 +469,14 @@ private[spark] object JsonProtocol { def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] - SparkListenerBlockManagerAdded(blockManagerId, maxMem) + val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + SparkListenerBlockManagerAdded(time, blockManagerId, maxMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") - SparkListenerBlockManagerRemoved(blockManagerId) + val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + SparkListenerBlockManagerRemoved(time, blockManagerId) } def unpersistRDDFromJson(json: JValue): SparkListenerUnpersistRDD = { @@ -476,9 +485,10 @@ private[spark] object JsonProtocol { def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { val appName = (json \ "App Name").extract[String] + val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String]) val time = (json \ "Timestamp").extract[Long] val sparkUser = (json \ "User").extract[String] - SparkListenerApplicationStart(appName, time, sparkUser) + SparkListenerApplicationStart(appName, appId, time, sparkUser) } def applicationEndFromJson(json: JValue): SparkListenerApplicationEnd = { @@ -492,6 +502,7 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] + val attemptId = (json \ "Attempt ID").extractOpt[Int].getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson(_)) @@ -504,7 +515,7 @@ private[spark] object JsonProtocol { case None => Seq[AccumulableInfo]() } - val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos, details) + val stageInfo = new StageInfo(stageId, attemptId, stageName, numTasks, rddInfos, details) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason @@ -638,8 +649,7 @@ private[spark] object JsonProtocol { val executorId = (json \ "Executor ID").extract[String] val host = (json \ "Host").extract[String] val port = (json \ "Port").extract[Int] - val nettyPort = (json \ "Netty Port").extract[Int] - BlockManagerId(executorId, host, port, nettyPort) + BlockManagerId(executorId, host, port) } def jobResultFromJson(json: JValue): JobResult = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d6d74ce269219..79943766d0f0f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -20,9 +20,11 @@ package org.apache.spark.util import java.io._ import java.net._ import java.nio.ByteBuffer -import java.util.{Locale, Random, UUID} +import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} +import org.apache.log4j.PropertyConfigurator + import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer @@ -34,6 +36,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} @@ -52,11 +55,6 @@ private[spark] case class CallSite(shortForm: String, longForm: String) private[spark] object Utils extends Logging { val random = new Random() - def sparkBin(sparkHome: String, which: String): File = { - val suffix = if (isWindows) ".cmd" else "" - new File(sparkHome + File.separator + "bin", which + suffix) - } - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -162,30 +160,6 @@ private[spark] object Utils extends Logging { } } - def isAlpha(c: Char): Boolean = { - (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') - } - - /** Split a string into words at non-alphabetic characters */ - def splitWords(s: String): Seq[String] = { - val buf = new ArrayBuffer[String] - var i = 0 - while (i < s.length) { - var j = i - while (j < s.length && isAlpha(s.charAt(j))) { - j += 1 - } - if (j > i) { - buf += s.substring(i, j) - } - i = j - while (i < s.length && !isAlpha(s.charAt(i))) { - i += 1 - } - } - buf - } - private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() @@ -347,7 +321,8 @@ private[spark] object Utils extends Logging { * Throws SparkException if the target file already exists and has different contents than * the requested file. */ - def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) { + def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager, + hadoopConf: Configuration) { val filename = url.split("/").last val tempDir = getLocalDir(conf) val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) @@ -419,7 +394,7 @@ private[spark] object Utils extends Logging { } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val fs = getHadoopFileSystem(uri) + val fs = getHadoopFileSystem(uri, hadoopConf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) @@ -830,14 +805,6 @@ private[spark] object Utils extends Logging { } } - /** - * Execute a command in the current working directory, throwing an exception if it completes - * with an exit code other than 0. - */ - def execute(command: Seq[String]) { - execute(command, new File(".")) - } - /** * Execute a command and get its output, throwing an exception if it yields a code other than 0. */ @@ -869,6 +836,7 @@ private[spark] object Utils extends Logging { val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output if (exitCode != 0) { + logError(s"Process $command exited with code $exitCode: ${output}") throw new SparkException("Process " + command + " exited with code " + exitCode) } output.toString @@ -899,8 +867,8 @@ private[spark] object Utils extends Logging { */ def getCallSite: CallSite = { val trace = Thread.currentThread.getStackTrace() - .filterNot { ste:StackTraceElement => - // When running under some profilers, the current stack trace might contain some bogus + .filterNot { ste:StackTraceElement => + // When running under some profilers, the current stack trace might contain some bogus // frames. This is intended to ensure that we don't crash in these situations by // ignoring any frames that we can't examine. (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")) @@ -1216,15 +1184,15 @@ private[spark] object Utils extends Logging { /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ - def getHadoopFileSystem(path: URI): FileSystem = { - FileSystem.get(path, SparkHadoopUtil.get.newConfiguration()) + def getHadoopFileSystem(path: URI, conf: Configuration): FileSystem = { + FileSystem.get(path, conf) } /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ - def getHadoopFileSystem(path: String): FileSystem = { - getHadoopFileSystem(new URI(path)) + def getHadoopFileSystem(path: String, conf: Configuration): FileSystem = { + getHadoopFileSystem(new URI(path), conf) } /** @@ -1301,7 +1269,7 @@ private[spark] object Utils extends Logging { } } - /** + /** * Execute the given block, logging and re-throwing any uncaught exception. * This is particularly useful for wrapping code that runs in a thread, to ensure * that exceptions are printed, and to avoid having to catch Throwable. @@ -1479,4 +1447,39 @@ private[spark] object Utils extends Logging { } } + /** + * config a log4j properties used for testsuite + */ + def configTestLog4j(level: String): Unit = { + val pro = new Properties() + pro.put("log4j.rootLogger", s"$level, console") + pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") + pro.put("log4j.appender.console.target", "System.err") + pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") + pro.put("log4j.appender.console.layout.ConversionPattern", + "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") + PropertyConfigurator.configure(pro) + } + +} + +/** + * A utility class to redirect the child process's stdout or stderr. + */ +private[spark] class RedirectThread(in: InputStream, out: OutputStream, name: String) + extends Thread(name) { + + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. + val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + out.write(buf, 0, len) + out.flush() + len = in.read(buf) + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 9f85b94a70800..8a015c1d26a96 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -413,7 +413,12 @@ class ExternalAppendOnlyMap[K, V, C]( extends Iterator[(K, C)] { private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 - assert(file.length() == batchOffsets(batchOffsets.length - 1)) + assert(file.length() == batchOffsets.last, + "File length is not equal to the last batch offset:\n" + + s" file length = ${file.length}\n" + + s" last batch offset = ${batchOffsets.last}\n" + + s" all batch offsets = ${batchOffsets.mkString(",")}" + ) private var batchIndex = 0 // Which batch we're in private var fileStream: FileInputStream = null diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 5d8a648d9551e..782b979e2e93d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -719,20 +719,20 @@ private[spark] class ExternalSorter[K, V, C]( def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) /** - * Write all the data added into this ExternalSorter into a file in the disk store, creating - * an .index file for it as well with the offsets of each partition. This is called by the - * SortShuffleWriter and can go through an efficient path of just concatenating binary files - * if we decided to avoid merge-sorting. + * Write all the data added into this ExternalSorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. * * @param blockId block ID to write to. The index file will be blockId.name + ".index". * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = { - val outputFile = blockManager.diskBlockManager.getFile(blockId) + def writePartitionedFile( + blockId: BlockId, + context: TaskContext, + outputFile: File): Array[Long] = { // Track location of each range in the output file - val offsets = new Array[Long](numPartitions + 1) val lengths = new Array[Long](numPartitions) if (bypassMergeSort && partitionWriters != null) { @@ -750,7 +750,6 @@ private[spark] class ExternalSorter[K, V, C]( in.close() in = null lengths(i) = size - offsets(i + 1) = offsets(i) + lengths(i) } } finally { if (out != null) { @@ -772,11 +771,7 @@ private[spark] class ExternalSorter[K, V, C]( } writer.commitAndClose() val segment = writer.fileSegment() - offsets(id + 1) = segment.offset + segment.length lengths(id) = segment.length - } else { - // The partition is empty; don't create a new writer to avoid writing headers, etc - offsets(id + 1) = offsets(id) } } } @@ -784,23 +779,6 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled context.taskMetrics.diskBytesSpilled += diskBytesSpilled - // Write an index file with the offsets of each block, plus a final offset at the end for the - // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure - // out where each block begins and ends. - - val diskBlockManager = blockManager.diskBlockManager - val indexFile = diskBlockManager.getFile(blockId.name + ".index") - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { - var i = 0 - while (i < numPartitions + 1) { - out.writeLong(offsets(i)) - i += 1 - } - } finally { - out.close() - } - lengths } @@ -811,7 +789,7 @@ private[spark] class ExternalSorter[K, V, C]( if (writer.isOpen) { writer.commitAndClose() } - blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]] + blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]] } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala new file mode 100644 index 0000000000000..daac6f971eb20 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.io.OutputStream + +import scala.collection.mutable.ArrayBuffer + + +/** + * An OutputStream that writes to fixed-size chunks of byte arrays. + * + * @param chunkSize size of each chunk, in bytes. + */ +private[spark] +class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { + + private val chunks = new ArrayBuffer[Array[Byte]] + + /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ + private var lastChunkIndex = -1 + + /** + * Next position to write in the last chunk. + * + * If this equals chunkSize, it means for next write we need to allocate a new chunk. + * This can also never be 0. + */ + private var position = chunkSize + + override def write(b: Int): Unit = { + allocateNewChunkIfNeeded() + chunks(lastChunkIndex)(position) = b.toByte + position += 1 + } + + override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + var written = 0 + while (written < len) { + allocateNewChunkIfNeeded() + val thisBatch = math.min(chunkSize - position, len - written) + System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch) + written += thisBatch + position += thisBatch + } + } + + @inline + private def allocateNewChunkIfNeeded(): Unit = { + if (position == chunkSize) { + chunks += new Array[Byte](chunkSize) + lastChunkIndex += 1 + position = 0 + } + } + + def toArrays: Array[Array[Byte]] = { + if (lastChunkIndex == -1) { + new Array[Array[Byte]](0) + } else { + // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. + // An alternative would have been returning an array of ByteBuffers, with the last buffer + // bounded to only the last chunk's position. However, given our use case in Spark (to put + // the chunks in block manager), only limiting the view bound of the buffer would still + // require the block manager to store the whole chunk. + val ret = new Array[Array[Byte]](chunks.size) + for (i <- 0 until chunks.size - 1) { + ret(i) = chunks(i) + } + if (position == chunkSize) { + ret(lastChunkIndex) = chunks(lastChunkIndex) + } else { + ret(lastChunkIndex) = new Array[Byte](position) + System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position) + } + ret + } + } +} diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e1c13de04a0be..b8574dfb42e6b 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -29,19 +29,14 @@ import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import com.google.common.collect.Sets; import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; -import org.apache.hadoop.mapred.FileSplit; -import org.apache.hadoop.mapred.InputSplit; import org.apache.hadoop.mapred.SequenceFileInputFormat; import org.apache.hadoop.mapred.SequenceFileOutputFormat; -import org.apache.hadoop.mapred.TextInputFormat; import org.apache.hadoop.mapreduce.Job; import org.junit.After; import org.junit.Assert; @@ -49,7 +44,6 @@ import org.junit.Test; import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaHadoopRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -189,6 +183,36 @@ public void sortByKey() { Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); } + @Test + public void repartitionAndSortWithinPartitions() { + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(0, 5)); + pairs.add(new Tuple2(3, 8)); + pairs.add(new Tuple2(2, 6)); + pairs.add(new Tuple2(0, 8)); + pairs.add(new Tuple2(3, 8)); + pairs.add(new Tuple2(1, 3)); + + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + Partitioner partitioner = new Partitioner() { + public int numPartitions() { + return 2; + } + public int getPartition(Object key) { + return ((Integer)key).intValue() % 2; + } + }; + + JavaPairRDD repartitioned = + rdd.repartitionAndSortWithinPartitions(partitioner); + List>> partitions = repartitioned.glom().collect(); + Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), + new Tuple2(0, 8), new Tuple2(2, 6))); + Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), + new Tuple2(3, 8), new Tuple2(3, 8))); + } + @Test public void emptyRDD() { JavaRDD rdd = sc.emptyRDD(); @@ -1283,23 +1307,4 @@ public void collectUnderlyingScalaRDD() { SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); Assert.assertEquals(data.size(), collected.length); } - - public void getHadoopInputSplits() { - String outDir = new File(tempDir, "output").getAbsolutePath(); - sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).saveAsTextFile(outDir); - - JavaHadoopRDD hadoopRDD = (JavaHadoopRDD) - sc.hadoopFile(outDir, TextInputFormat.class, LongWritable.class, Text.class); - List inputPaths = hadoopRDD.mapPartitionsWithInputSplit( - new Function2>, Iterator>() { - @Override - public Iterator call(InputSplit split, Iterator> it) - throws Exception { - FileSplit fileSplit = (FileSplit) split; - return Lists.newArrayList(fileSplit.getPath().toUri().getPath()).iterator(); - } - }, true).collect(); - Assert.assertEquals(Sets.newHashSet(inputPaths), - Sets.newHashSet(outDir + "/part-00000", outDir + "/part-00001")); - } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 9c5f394d3899d..90dcadcffd091 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -32,6 +32,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar var split: Partition = _ /** An RDD which returns the values [1, 2, 3, 4]. */ var rdd: RDD[Int] = _ + var rdd2: RDD[Int] = _ + var rdd3: RDD[Int] = _ before { sc = new SparkContext("local", "test") @@ -43,6 +45,16 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar override val getDependencies = List[Dependency[_]]() override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator } + rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) { + override def getPartitions: Array[Partition] = firstParent[Int].partitions + override def compute(split: Partition, context: TaskContext) = + firstParent[Int].iterator(split, context) + }.cache() + rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) { + override def getPartitions: Array[Partition] = firstParent[Int].partitions + override def compute(split: Partition, context: TaskContext) = + firstParent[Int].iterator(split, context) + }.cache() } after { @@ -87,4 +99,11 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar assert(value.toList === List(1, 2, 3, 4)) } } + + test("verify task metrics updated correctly") { + cacheManager = sc.env.cacheManager + val context = new TaskContext(0, 0, 0) + cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) + assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) + } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 4bc4346c0a288..2e3fc5ef0e336 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -21,7 +21,6 @@ import java.lang.ref.WeakReference import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials -import scala.language.postfixOps import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} @@ -52,6 +51,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha .setMaster("local[2]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") .set("spark.shuffle.manager", shuffleManager.getName) before { @@ -243,6 +243,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { .setMaster("local-cluster[2, 1, 512]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") .set("spark.shuffle.manager", shuffleManager.getName) sc = new SparkContext(conf2) @@ -319,6 +320,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor .setMaster("local-cluster[2, 1, 512]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") .set("spark.shuffle.manager", shuffleManager.getName) sc = new SparkContext(conf2) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 41c294f727b3c..81b64c36ddca1 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -24,8 +24,7 @@ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} import org.apache.spark.SparkContext._ -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} @@ -136,7 +135,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter sc.parallelize(1 to 10, 2).foreach { x => if (x == 1) System.exit(42) } } assert(thrown.getClass === classOf[SparkException]) - System.out.println(thrown.getMessage) assert(thrown.getMessage.contains("failed 4 times")) } } @@ -202,12 +200,13 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager - blockManager.master.getLocations(blockId).foreach(id => { - val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.host, id.port)) - val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList + val blockTransfer = SparkEnv.get.blockTransferService + blockManager.master.getLocations(blockId).foreach { cmId => + val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString) + val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer()) + .asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) - }) + } } test("compute without caching when no partitions fit in memory") { diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index a73e1ef0288a5..5265ba904032f 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark import java.io.File -import org.apache.log4j.Logger -import org.apache.log4j.Level - import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.prop.TableDrivenPropertyChecks._ @@ -29,8 +26,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils -import scala.language.postfixOps - class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { @@ -54,7 +49,7 @@ class DriverSuite extends FunSuite with Timeouts { */ object DriverWithoutCleanup { def main(args: Array[String]) { - Logger.getRootLogger().setLevel(Level.WARN) + Utils.configTestLog4j("INFO") val sc = new SparkContext(args(0), "DriverWithoutCleanup") sc.parallelize(1 to 100, 4).count() } diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala new file mode 100644 index 0000000000000..2acc02a54fa3d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.BeforeAndAfterAll + +class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with hash-based shuffle. + + override def beforeAll() { + System.setProperty("spark.shuffle.manager", "hash") + } + + override def afterAll() { + System.clearProperty("spark.shuffle.manager") + } +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 9702838085627..5369169811f81 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -69,13 +69,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000), - (BlockManagerId("b", "hostB", 1000, 0), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), + (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() } @@ -86,9 +86,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getServerStatuses(10, 0).nonEmpty) @@ -105,14 +105,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simultaneous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -145,13 +145,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + Seq((BlockManagerId("a", "hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } @@ -174,7 +174,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Byte](10)(0))) masterActor.receive(GetMapOutputStatuses(10)) } @@ -195,7 +195,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new MapStatus( - BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Byte](4000000)(0))) } intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index b13ddf96bc77c..15aa4d83800fa 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.MutablePair -class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { +abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val conf = new SparkConf(loadDefaults = false) diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 5c02c00586ef4..639e56c488db4 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -24,8 +24,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. override def beforeAll() { - System.setProperty("spark.shuffle.manager", - "org.apache.spark.shuffle.sort.SortShuffleManager") + System.setProperty("spark.shuffle.manager", "sort") } override def afterAll() { diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 31aa7ec837f43..2a58c6a40d8e4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -121,8 +121,8 @@ class JsonProtocolSuite extends FunSuite { new SparkConf, ExecutorState.RUNNING) } def createDriverRunner(): DriverRunner = { - new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(), - null, "akka://worker") + new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), + createDriverDesc(), null, "akka://worker") } def assertValidJson(json: JValue) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7e1ef80c84561..22b369a829418 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -317,6 +317,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { object JarCreationTest { def main(args: Array[String]) { + Utils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => @@ -338,6 +339,7 @@ object JarCreationTest { object SimpleApplicationTest { def main(args: Array[String]) { + Utils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val configs = Seq("spark.master", "spark.app.name") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index c930839b47f11..b6f4411e0587a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -25,14 +25,15 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.FunSuite +import org.apache.spark.SparkConf import org.apache.spark.deploy.{Command, DriverDescription} class DriverRunnerTest extends FunSuite { private def createDriverRunner() = { val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) - new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription, - null, "akka://1.2.3.4/worker/") + new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), + driverDescription, null, "akka://1.2.3.4/worker/") } private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala similarity index 97% rename from core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala rename to core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index e2f4d4c57cdb5..9f49587cdc670 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -15,23 +15,18 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.nio import java.io.IOException import java.nio._ -import java.util.concurrent.TimeoutException -import org.apache.spark.{SecurityManager, SparkConf} -import org.scalatest.FunSuite - -import org.mockito.Mockito._ -import org.mockito.Matchers._ - -import scala.concurrent.TimeoutException -import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ +import scala.concurrent.{Await, TimeoutException} import scala.language.postfixOps -import scala.util.{Failure, Success, Try} + +import org.scalatest.FunSuite + +import org.apache.spark.{SecurityManager, SparkConf} /** * Test the ConnectionManager with various security settings. diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 28197657e9bad..3b833f2e41867 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -22,7 +22,6 @@ import java.util.concurrent.Semaphore import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global -import scala.language.postfixOps import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Timeouts diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 956c2b9cbd321..8408d7e785c65 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -38,9 +38,7 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { Iterator() } } - val prunedRDD = PartitionPruningRDD.create(rdd, { - x => if (x == 2) true else false - }) + val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2) assert(prunedRDD.partitions.length == 1) val p = prunedRDD.partitions(0) assert(p.index == 0) @@ -62,13 +60,10 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { List(split.asInstanceOf[TestPartition].testValue).iterator } } - val prunedRDD1 = PartitionPruningRDD.create(rdd, { - x => if (x == 0) true else false - }) + val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0) - val prunedRDD2 = PartitionPruningRDD.create(rdd, { - x => if (x == 2) true else false - }) + + val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2) val merged = prunedRDD1 ++ prunedRDD2 assert(merged.count() == 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 926d4fecb5b91..c1b501a75c8b8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -521,6 +521,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sortedLowerK === Array(1, 2, 3, 4, 5)) } + test("takeOrdered with limit 0") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val rdd = sc.makeRDD(nums, 2) + val sortedLowerK = rdd.takeOrdered(0) + assert(sortedLowerK.size === 0) + } + test("takeOrdered with custom ordering") { val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) implicit val ord = implicitly[Ordering[Int]].reverse @@ -675,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered) } + test("repartitionAndSortWithinPartitions") { + val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2) + + val partitioner = new Partitioner { + def numPartitions: Int = 2 + def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2 + } + + val repartitioned = data.repartitionAndSortWithinPartitions(partitioner) + val partitions = repartitioned.glom().collect() + assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6))) + assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8))) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bd829752eb401..aa73469b6acd8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.mutable.{HashSet, HashMap, Map} +import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import akka.actor._ @@ -27,6 +27,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} @@ -97,10 +98,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 val sparkListener = new SparkListener() { - val successfulStages = new HashSet[Int]() - val failedStages = new HashSet[Int]() + val successfulStages = new HashSet[Int] + val failedStages = new ArrayBuffer[Int] + val stageByOrderOfExecution = new ArrayBuffer[Int] override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { val stageInfo = stageCompleted.stageInfo + stageByOrderOfExecution += stageInfo.stageId if (stageInfo.failureReason.isEmpty) { successfulStages += stageInfo.stageId } else { @@ -120,7 +123,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations - val blockManagerMaster = new BlockManagerMaster(null, conf) { + val blockManagerMaster = new BlockManagerMaster(null, conf, true) { override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). @@ -231,6 +234,13 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runEvent(JobCancelled(jobId)) } + test("[SPARK-3353] parent stage should have lower stage id") { + sparkListener.stageByOrderOfExecution.clear() + sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count() + assert(sparkListener.stageByOrderOfExecution.length === 2) + assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1)) + } + test("zero split job") { var numResults = 0 val fakeListener = new JobListener() { @@ -435,6 +445,43 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } + test("trivial shuffle with multiple fetch failures") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + // The MapOutputTracker should know about both map output locations. + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === + Array("hostA", "hostB")) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), + null, + Map[Long, Any](), + null, + null)) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(1)) + + // The second ResultTask fails, with a fetch failure for the output from the second mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1), + null, + Map[Long, Any](), + null, + null)) + // The SparkListener should not receive redundant failure events. + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.size == 1) + } + test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -478,8 +525,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // Listener bus should get told about the map stage failing, but not the reduce stage // (since the reduce stage hasn't been started yet). assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) - assert(sparkListener.failedStages.contains(1)) - assert(sparkListener.failedStages.size === 1) + assert(sparkListener.failedStages.toSet === Set(0)) assertDataStructuresEmpty } @@ -526,14 +572,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) - assert(cancelledStages.contains(1)) + assert(cancelledStages.toSet === Set(0, 2)) // Make sure the listeners got told about both failed stages. assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.successfulStages.isEmpty) - assert(sparkListener.failedStages.contains(1)) - assert(sparkListener.failedStages.contains(3)) - assert(sparkListener.failedStages.size === 2) + assert(sparkListener.failedStages.toSet === Set(0, 2)) assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") @@ -699,7 +743,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) private def makeBlockManagerId(host: String): BlockManagerId = - BlockManagerId("exec-" + host, host, 12345, 0) + BlockManagerId("exec-" + host, host, 12345) private def assertDataStructuresEmpty = { assert(scheduler.activeJobs.isEmpty) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 10d8b299317ea..e5315bc93e217 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -26,7 +26,9 @@ import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec +import org.apache.spark.SPARK_VERSION import org.apache.spark.util.{JsonProtocol, Utils} import java.io.File @@ -39,7 +41,8 @@ import java.io.File * read and deserialized into actual SparkListenerEvents. */ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { - private val fileSystem = Utils.getHadoopFileSystem("/") + private val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) private val allCompressionCodecs = Seq[String]( "org.apache.spark.io.LZFCompressionCodec", "org.apache.spark.io.SnappyCompressionCodec" @@ -194,7 +197,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { def assertInfoCorrect(info: EventLoggingInfo, loggerStopped: Boolean) { assert(info.logPaths.size > 0) - assert(info.sparkVersion === SparkContext.SPARK_VERSION) + assert(info.sparkVersion === SPARK_VERSION) assert(info.compressionCodec.isDefined === compressionCodec.isDefined) info.compressionCodec.foreach { codec => assert(compressionCodec.isDefined) @@ -227,7 +230,8 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { val conf = getLoggingConf(logDirPath, compressionCodec) val eventLogger = new EventLoggingListener("test", conf) val listenerBus = new LiveListenerBus - val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", 125L, "Mickey") + val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, + 125L, "Mickey") val applicationEnd = SparkListenerApplicationEnd(1000L) // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite @@ -378,7 +382,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { private def assertSparkVersionIsValid(logFiles: Array[FileStatus]) { val file = logFiles.map(_.getPath.getName).find(EventLoggingListener.isSparkVersionFile) assert(file.isDefined) - assert(EventLoggingListener.parseSparkVersion(file.get) === SparkContext.SPARK_VERSION) + assert(EventLoggingListener.parseSparkVersion(file.get) === SPARK_VERSION) } private def assertCompressionCodecIsValid(logFiles: Array[FileStatus], compressionCodec: String) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index d81499ac6abef..7ab351d1b4d24 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -32,11 +33,9 @@ import org.apache.spark.util.{JsonProtocol, Utils} * Test whether ReplayListenerBus replays events from logs correctly. */ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { - private val fileSystem = Utils.getHadoopFileSystem("/") - private val allCompressionCodecs = Seq[String]( - "org.apache.spark.io.LZFCompressionCodec", - "org.apache.spark.io.SnappyCompressionCodec" - ) + private val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) + private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private var testDir: File = _ before { @@ -84,7 +83,8 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { val fstream = fileSystem.create(logFilePath) val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream) val writer = new PrintWriter(cstream) - val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", 125L, "Mickey") + val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, + 125L, "Mickey") val applicationEnd = SparkListenerApplicationEnd(1000L) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 3b0b8e2f68c97..ab35e8edc4ebf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -180,7 +180,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers rdd3.count() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {2} // Shuffle map stage + result stage - val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get + val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get stageInfo3.rddInfos.size should be {1} // ShuffledRDD stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true} stageInfo3.rddInfos.exists(_.name == "Trois") should be {true} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala new file mode 100644 index 0000000000000..ba47fe5e25b9b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import java.io.{File, FileWriter} + +import scala.language.reflectiveCalls + +import org.scalatest.FunSuite + +import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FileShuffleBlockManager +import org.apache.spark.storage.{ShuffleBlockId, FileSegment} + +class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { + private val testConf = new SparkConf(false) + + private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { + assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) + val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] + assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath) + assert(expected.offset === segment.offset) + assert(expected.length === segment.length) + } + + test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { + + val conf = new SparkConf(false) + // reset after EACH object write. This is to ensure that there are bytes appended after + // an object is written. So if the codepaths assume writeObject is end of data, this should + // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + + sc = new SparkContext("local", "test", conf) + + val shuffleBlockManager = + SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager] + + val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), + new ShuffleWriteMetrics) + for (writer <- shuffle1.writers) { + writer.write("test1") + writer.write("test2") + } + for (writer <- shuffle1.writers) { + writer.commitAndClose() + } + + val shuffle1Segment = shuffle1.writers(0).fileSegment() + shuffle1.releaseWriters(success = true) + + val shuffle2 = shuffleBlockManager.forMapTask(1, 2, 1, new JavaSerializer(conf), + new ShuffleWriteMetrics) + + for (writer <- shuffle2.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle2.writers) { + writer.commitAndClose() + } + val shuffle2Segment = shuffle2.writers(0).fileSegment() + shuffle2.releaseWriters(success = true) + + // Now comes the test : + // Write to shuffle 3; and close it, but before registering it, check if the file lengths for + // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length + // of block based on remaining data in file : which could mess things up when there is concurrent read + // and writes happening to the same shuffle group. + + val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), + new ShuffleWriteMetrics) + for (writer <- shuffle3.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle3.writers) { + writer.commitAndClose() + } + // check before we register. + checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0))) + shuffle3.releaseWriters(success = true) + checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0))) + shuffleBlockManager.removeShuffle(1) + } + + def writeToFile(file: File, numBytes: Int) { + val writer = new FileWriter(file, true) + for (i <- 0 until numBytes) writer.write(i) + writer.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala deleted file mode 100644 index bcbfe8baf36ad..0000000000000 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.io.IOException -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.future -import scala.concurrent.ExecutionContext.Implicits.global - -import org.scalatest.{FunSuite, Matchers} - -import org.mockito.Mockito._ -import org.mockito.Matchers.{any, eq => meq} -import org.mockito.stubbing.Answer -import org.mockito.invocation.InvocationOnMock - -import org.apache.spark.storage.BlockFetcherIterator._ -import org.apache.spark.network.{ConnectionManager, Message} -import org.apache.spark.executor.ShuffleReadMetrics - -class BlockFetcherIteratorSuite extends FunSuite with Matchers { - - test("block fetch from local fails using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - doReturn(connManager).when(blockManager).connectionManager - doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId - - doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - val answer = new Answer[Option[Iterator[Any]]] { - override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { - throw new Exception - } - } - - // 3rd block is going to fail - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) - doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) - - val bmId = BlockManagerId("test-client", "test-client",1 , 0) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, - new ShuffleReadMetrics()) - - iterator.initialize() - - // 3rd getLocalFromDisk invocation should be failed - verify(blockManager, times(3)).getLocalFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully - assert(iterator.next._2.isDefined, "1st element should be defined but is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next._2.isDefined, "2nd element should be defined but is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - // 3rd fetch should be failed - assert(!iterator.next._2.isDefined, "3rd element should not be defined but is actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - // Don't call next() after fetching non-defined element even if thare are rest of elements in the iterator. - // Otherwise, BasicBlockFetcherIterator hangs up. - } - - - test("block fetch from local succeed using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - doReturn(connManager).when(blockManager).connectionManager - doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId - - doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - - // All blocks should be fetched successfully - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) - - val bmId = BlockManagerId("test-client", "test-client",1 , 0) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, - new ShuffleReadMetrics()) - - iterator.initialize() - - // getLocalFromDis should be invoked for all of 5 blocks - verify(blockManager, times(5)).getLocalFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") - } - - test("block fetch from remote fails using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - when(blockManager.connectionManager).thenReturn(connManager) - - val f = future { - throw new IOException("Send failed or we received an error ACK") - } - when(connManager.sendMessageReliably(any(), - any())).thenReturn(f) - when(blockManager.futureExecContext).thenReturn(global) - - when(blockManager.blockManagerId).thenReturn( - BlockManagerId("test-client", "test-client", 1, 0)) - when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) - - val blId1 = ShuffleBlockId(0,0,0) - val blId2 = ShuffleBlockId(0,1,0) - val bmId = BlockManagerId("test-server", "test-server",1 , 0) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L))) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null, new ShuffleReadMetrics()) - - iterator.initialize() - iterator.foreach{ - case (_, r) => { - (!r.isDefined) should be(true) - } - } - } - - test("block fetch from remote succeed using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - when(blockManager.connectionManager).thenReturn(connManager) - - val blId1 = ShuffleBlockId(0,0,0) - val blId2 = ShuffleBlockId(0,1,0) - val buf1 = ByteBuffer.allocate(4) - val buf2 = ByteBuffer.allocate(4) - buf1.putInt(1) - buf1.flip() - buf2.putInt(1) - buf2.flip() - val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1)) - val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2)) - val blockMessageArray = new BlockMessageArray( - Seq(blockMessage1, blockMessage2)) - - val bufferMessage = blockMessageArray.toBufferMessage - val buffer = ByteBuffer.allocate(bufferMessage.size) - val arrayBuffer = new ArrayBuffer[ByteBuffer] - bufferMessage.buffers.foreach{ b => - buffer.put(b) - } - buffer.flip() - arrayBuffer += buffer - - val f = future { - Message.createBufferMessage(arrayBuffer) - } - when(connManager.sendMessageReliably(any(), - any())).thenReturn(f) - when(blockManager.futureExecContext).thenReturn(global) - - when(blockManager.blockManagerId).thenReturn( - BlockManagerId("test-client", "test-client", 1, 0)) - when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) - - val bmId = BlockManagerId("test-server", "test-server",1 , 0) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L))) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null, new ShuffleReadMetrics()) - iterator.initialize() - iterator.foreach{ - case (_, r) => { - (r.isDefined) should be(true) - } - } - } -} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f32ce6f9fcc7f..e251660dae5de 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,15 +21,19 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import java.util.concurrent.TimeUnit +import org.apache.spark.network.nio.NioBlockTransferService + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import akka.actor._ import akka.pattern.ask import akka.util.Timeout -import org.apache.spark.shuffle.hash.HashShuffleManager -import org.mockito.invocation.InvocationOnMock -import org.mockito.Matchers.any -import org.mockito.Mockito.{doAnswer, mock, spy, when} -import org.mockito.stubbing.Answer +import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ @@ -38,17 +42,12 @@ import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.{Message, ConnectionManagerId} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.implicitConversions -import scala.language.postfixOps class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter with PrivateMethodTester { @@ -73,8 +72,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { - new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr, - mapOutputTracker, shuffleManager) + val transfer = new NioBlockTransferService(conf, securityMgr) + new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer) } before { @@ -92,7 +92,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter master = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf) + conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -139,9 +139,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("BlockManagerId object caching") { - val id1 = BlockManagerId("e1", "XXX", 1, 0) - val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1 - val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object + val id1 = BlockManagerId("e1", "XXX", 1) + val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") assert(id3 != id1, "id3 is same as id1") @@ -792,8 +792,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. + val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr, mapOutputTracker, shuffleManager) + mapOutputTracker, shuffleManager, transfer) // The put should fail since a1 is not serializable. class UnserializableClass @@ -823,11 +824,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) - val shuffleBlockManager = mock(classOf[ShuffleBlockManager]) - when(shuffleBlockManager.conf).thenReturn(conf) - val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) + val diskBlockManager = new DiskBlockManager(blockManager, conf) + val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val mapped = diskStoreMapped.getBytes(blockId).get @@ -1006,109 +1005,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } - test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker, shuffleManager) - - val worker = spy(new BlockManagerWorker(store)) - val connManagerId = mock(classOf[ConnectionManagerId]) - - // setup request block messages - val reqBlId1 = ShuffleBlockId(0,0,0) - val reqBlId2 = ShuffleBlockId(0,1,0) - val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) - val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) - val reqBlockMessages = new BlockMessageArray( - Seq(reqBlockMessage1, reqBlockMessage2)) - val reqBufferMessage = reqBlockMessages.toBufferMessage - - val answer = new Answer[Option[BlockMessage]] { - override def answer(invocation: InvocationOnMock) - :Option[BlockMessage]= { - throw new Exception - } - } - - doAnswer(answer).when(worker).processBlockMessage(any()) - - // Test when exception was thrown during processing block messages - var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) - - assert(ackMessage.isDefined, "When Exception was thrown in " + - "BlockManagerWorker#processBlockMessage, " + - "ackMessage should be defined") - assert(ackMessage.get.hasError, "When Exception was thown in " + - "BlockManagerWorker#processBlockMessage, " + - "ackMessage should have error") - - val notBufferMessage = mock(classOf[Message]) - - // Test when not BufferMessage was received - ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId) - assert(ackMessage.isDefined, "When not BufferMessage was passed to " + - "BlockManagerWorker#onBlockMessageReceive, " + - "ackMessage should be defined") - assert(ackMessage.get.hasError, "When not BufferMessage was passed to " + - "BlockManagerWorker#onBlockMessageReceive, " + - "ackMessage should have error") - } - - test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker, shuffleManager) - - val worker = spy(new BlockManagerWorker(store)) - val connManagerId = mock(classOf[ConnectionManagerId]) - - // setup request block messages - val reqBlId1 = ShuffleBlockId(0,0,0) - val reqBlId2 = ShuffleBlockId(0,1,0) - val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) - val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) - val reqBlockMessages = new BlockMessageArray( - Seq(reqBlockMessage1, reqBlockMessage2)) - - val tmpBufferMessage = reqBlockMessages.toBufferMessage - val buffer = ByteBuffer.allocate(tmpBufferMessage.size) - val arrayBuffer = new ArrayBuffer[ByteBuffer] - tmpBufferMessage.buffers.foreach{ b => - buffer.put(b) - } - buffer.flip() - arrayBuffer += buffer - val reqBufferMessage = Message.createBufferMessage(arrayBuffer) - - // setup ack block messages - val buf1 = ByteBuffer.allocate(4) - val buf2 = ByteBuffer.allocate(4) - buf1.putInt(1) - buf1.flip() - buf2.putInt(1) - buf2.flip() - val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1)) - val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2)) - - val answer = new Answer[Option[BlockMessage]] { - override def answer(invocation: InvocationOnMock) - :Option[BlockMessage]= { - if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq( - reqBlockMessage1)) { - return Some(ackBlockMessage1) - } else { - return Some(ackBlockMessage2) - } - } - } - - doAnswer(answer).when(worker).processBlockMessage(any()) - - val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) - assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " + - "was executed successfully, ackMessage should be defined") - assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " + - "was executed successfully, ackMessage should not have error") - } - test("reserve/release unroll memory") { store = makeBlockManager(12000) val memoryStore = store.memoryStore diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index aabaeadd7a071..e4522e00a622d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.shuffle.hash.HashShuffleManager import scala.collection.mutable @@ -26,6 +27,7 @@ import scala.language.reflectiveCalls import akka.actor.Props import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf @@ -40,18 +42,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before private var rootDir1: File = _ private var rootDirs: String = _ - // This suite focuses primarily on consolidation features, - // so we coerce consolidation if not already enabled. - testConf.set("spark.shuffle.consolidateFiles", "true") - - private val shuffleManager = new HashShuffleManager(testConf.clone) - - val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) { - override def conf = testConf.clone - var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() - override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) - } - + val blockManager = mock(classOf[BlockManager]) + when(blockManager.conf).thenReturn(testConf) var diskBlockManager: DiskBlockManager = _ override def beforeAll() { @@ -61,7 +53,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before rootDir1 = Files.createTempDir() rootDir1.deleteOnExit() rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath - println("Created root dirs: " + rootDirs) } override def afterAll() { @@ -73,22 +64,17 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before override def beforeEach() { val conf = testConf.clone conf.set("spark.local.dir", rootDirs) - diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) - shuffleBlockManager.idToSegmentMap.clear() + diskBlockManager = new DiskBlockManager(blockManager, conf) } override def afterEach() { diskBlockManager.stop() - shuffleBlockManager.idToSegmentMap.clear() } test("basic block creation") { val blockId = new TestBlockId("test") - assertSegmentEquals(blockId, blockId.name, 0, 0) - val newFile = diskBlockManager.getFile(blockId) writeToFile(newFile, 10) - assertSegmentEquals(blockId, blockId.name, 0, 10) assert(diskBlockManager.containsBlock(blockId)) newFile.delete() assert(!diskBlockManager.containsBlock(blockId)) @@ -101,127 +87,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } - test("block appending") { - val blockId = new TestBlockId("test") - val newFile = diskBlockManager.getFile(blockId) - writeToFile(newFile, 15) - assertSegmentEquals(blockId, blockId.name, 0, 15) - val newFile2 = diskBlockManager.getFile(blockId) - assert(newFile === newFile2) - writeToFile(newFile2, 12) - assertSegmentEquals(blockId, blockId.name, 0, 27) - newFile.delete() - } - - test("block remapping") { - val filename = "test" - val blockId0 = new ShuffleBlockId(1, 2, 3) - val newFile = diskBlockManager.getFile(filename) - writeToFile(newFile, 15) - shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15) - assertSegmentEquals(blockId0, filename, 0, 15) - - val blockId1 = new ShuffleBlockId(1, 2, 4) - val newFile2 = diskBlockManager.getFile(filename) - writeToFile(newFile2, 12) - shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12) - assertSegmentEquals(blockId1, filename, 15, 12) - - assert(newFile === newFile2) - newFile.delete() - } - - private def checkSegments(segment1: FileSegment, segment2: FileSegment) { - assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath) - assert (segment1.offset === segment2.offset) - assert (segment1.length === segment2.length) - } - - test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { - - val serializer = new JavaSerializer(testConf) - val confCopy = testConf.clone - // reset after EACH object write. This is to ensure that there are bytes appended after - // an object is written. So if the codepaths assume writeObject is end of data, this should - // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. - confCopy.set("spark.serializer.objectStreamReset", "1") - - val securityManager = new org.apache.spark.SecurityManager(confCopy) - // Do not use the shuffleBlockManager above ! - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy, - securityManager) - val master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), - confCopy) - val store = new BlockManager("", actorSystem, master , serializer, confCopy, - securityManager, null, shuffleManager) - - try { - - val shuffleManager = store.shuffleBlockManager - - val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics) - for (writer <- shuffle1.writers) { - writer.write("test1") - writer.write("test2") - } - for (writer <- shuffle1.writers) { - writer.commitAndClose() - } - - val shuffle1Segment = shuffle1.writers(0).fileSegment() - shuffle1.releaseWriters(success = true) - - val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - - for (writer <- shuffle2.writers) { - writer.write("test3") - writer.write("test4") - } - for (writer <- shuffle2.writers) { - writer.commitAndClose() - } - val shuffle2Segment = shuffle2.writers(0).fileSegment() - shuffle2.releaseWriters(success = true) - - // Now comes the test : - // Write to shuffle 3; and close it, but before registering it, check if the file lengths for - // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is concurrent read - // and writes happening to the same shuffle group. - - val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - for (writer <- shuffle3.writers) { - writer.write("test3") - writer.write("test4") - } - for (writer <- shuffle3.writers) { - writer.commitAndClose() - } - // check before we register. - checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0))) - shuffle3.releaseWriters(success = true) - checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0))) - shuffleManager.removeShuffle(1) - } finally { - - if (store != null) { - store.stop() - } - actorSystem.shutdown() - actorSystem.awaitTermination() - } - } - - def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) { - val segment = diskBlockManager.getBlockLocation(blockId) - assert(segment.file.getName === filename) - assert(segment.offset === offset) - assert(segment.length === length) - } - def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala new file mode 100644 index 0000000000000..809bd70929656 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.TaskContext +import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} + +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.scalatest.FunSuite + + +class ShuffleBlockFetcherIteratorSuite extends FunSuite { + + test("handle local read failures in BlockManager") { + val transfer = mock(classOf[BlockTransferService]) + val blockManager = mock(classOf[BlockManager]) + doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + val answer = new Answer[Option[Iterator[Any]]] { + override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { + throw new Exception + } + } + + // 3rd block is going to fail + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) + doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client", 1) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new ShuffleBlockFetcherIterator( + new TaskContext(0, 0, 0), + transfer, + blockManager, + blocksByAddress, + null, + 48 * 1024 * 1024) + + // Without exhausting the iterator, the iterator should be lazy and not call + // getLocalShuffleFromDisk. + verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + // the 2nd element of the tuple returned by iterator.next should be defined when + // fetching successfully + assert(iterator.next()._2.isDefined, + "1st element should be defined but is not actually defined") + verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next()._2.isDefined, + "2nd element should be defined but is not actually defined") + verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + // 3rd fetch should be failed + intercept[Exception] { + iterator.next() + } + verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any()) + } + + test("handle local read successes") { + val transfer = mock(classOf[BlockTransferService]) + val blockManager = mock(classOf[BlockManager]) + doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + + // All blocks should be fetched successfully + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client", 1) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new ShuffleBlockFetcherIterator( + new TaskContext(0, 0, 0), + transfer, + blockManager, + blocksByAddress, + null, + 48 * 1024 * 1024) + + // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk. + verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 1st element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 2nd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 3rd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 4th element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 5th element is not actually defined") + + verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any()) + } + + test("handle remote fetch failures in BlockTransferService") { + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + listener.onBlockFetchFailure(new Exception("blah")) + } + }) + + val blockManager = mock(classOf[BlockManager]) + + when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1)) + + val blId1 = ShuffleBlockId(0, 0, 0) + val blId2 = ShuffleBlockId(0, 1, 0) + val bmId = BlockManagerId("test-server", "test-server", 1) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L)))) + + val iterator = new ShuffleBlockFetcherIterator( + new TaskContext(0, 0, 0), + transfer, + blockManager, + blocksByAddress, + null, + 48 * 1024 * 1024) + + iterator.foreach { case (_, iterOption) => + assert(!iterOption.isDefined) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 51fb646a3cb61..3a45875391e29 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -26,8 +26,8 @@ import org.apache.spark.scheduler._ * Test the behavior of StorageStatusListener in response to all relevant events. */ class StorageStatusListenerSuite extends FunSuite { - private val bm1 = BlockManagerId("big", "dog", 1, 1) - private val bm2 = BlockManagerId("fat", "duck", 2, 2) + private val bm1 = BlockManagerId("big", "dog", 1) + private val bm2 = BlockManagerId("fat", "duck", 2) private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) private val taskInfo2 = new TaskInfo(0, 0, 0, 0, "fat", "duck", TaskLocality.ANY, false) @@ -36,13 +36,13 @@ class StorageStatusListenerSuite extends FunSuite { // Block manager add assert(listener.executorIdToStorageStatus.size === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) assert(listener.executorIdToStorageStatus.size === 1) assert(listener.executorIdToStorageStatus.get("big").isDefined) assert(listener.executorIdToStorageStatus("big").blockManagerId === bm1) assert(listener.executorIdToStorageStatus("big").maxMem === 1000L) assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) assert(listener.executorIdToStorageStatus.size === 2) assert(listener.executorIdToStorageStatus.get("fat").isDefined) assert(listener.executorIdToStorageStatus("fat").blockManagerId === bm2) @@ -50,11 +50,11 @@ class StorageStatusListenerSuite extends FunSuite { assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) // Block manager remove - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(bm1)) + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm1)) assert(listener.executorIdToStorageStatus.size === 1) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(listener.executorIdToStorageStatus.get("fat").isDefined) - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(bm2)) + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm2)) assert(listener.executorIdToStorageStatus.size === 0) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(!listener.executorIdToStorageStatus.get("fat").isDefined) @@ -62,25 +62,25 @@ class StorageStatusListenerSuite extends FunSuite { test("task end without updated blocks") { val listener = new StorageStatusListener - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics = new TaskMetrics // Task end with no updated blocks assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics)) assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics)) assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) } test("task end with updated blocks") { val listener = new StorageStatusListener - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics1 = new TaskMetrics val taskMetrics2 = new TaskMetrics val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) @@ -92,13 +92,13 @@ class StorageStatusListenerSuite extends FunSuite { // Task end with new blocks assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) assert(listener.executorIdToStorageStatus("big").numBlocks === 2) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2)) assert(listener.executorIdToStorageStatus("big").numBlocks === 2) assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) @@ -111,13 +111,14 @@ class StorageStatusListenerSuite extends FunSuite { val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3)) taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3)) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) + + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) assert(listener.executorIdToStorageStatus("big").numBlocks === 1) assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2)) assert(listener.executorIdToStorageStatus("big").numBlocks === 1) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) @@ -127,7 +128,7 @@ class StorageStatusListenerSuite extends FunSuite { test("unpersist RDD") { val listener = new StorageStatusListener - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) val taskMetrics1 = new TaskMetrics val taskMetrics2 = new TaskMetrics val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) @@ -135,8 +136,8 @@ class StorageStatusListenerSuite extends FunSuite { val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L, 0L)) taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) taskMetrics2.updatedBlocks = Some(Seq(block3)) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) - listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics2)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2)) assert(listener.executorIdToStorageStatus("big").numBlocks === 3) // Unpersist RDD diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index 38678bbd1dd28..ef5c55f91c39a 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -27,7 +27,7 @@ class StorageSuite extends FunSuite { // For testing add, update, and remove (for non-RDD blocks) private def storageStatus1: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) assert(status.blocks.isEmpty) assert(status.rddBlocks.isEmpty) assert(status.memUsed === 0L) @@ -78,7 +78,7 @@ class StorageSuite extends FunSuite { // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) assert(status.rddBlocks.isEmpty) status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L, 0L)) status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L, 0L)) @@ -271,9 +271,9 @@ class StorageSuite extends FunSuite { // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations private def stockStorageStatuses: Seq[StorageStatus] = { - val status1 = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L) - val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2, 2), 2000L) - val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3, 3), 3000L) + val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L) + val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L) status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L, 0L)) status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L, 0L)) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 038746d2eda4b..48790b59e7fbd 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -21,7 +21,6 @@ import java.net.ServerSocket import javax.servlet.http.HttpServletRequest import scala.io.Source -import scala.language.postfixOps import scala.util.{Failure, Success, Try} import org.eclipse.jetty.server.Server @@ -36,11 +35,25 @@ import scala.xml.Node class UISuite extends FunSuite { + /** + * Create a test SparkContext with the SparkUI enabled. + * It is safe to `get` the SparkUI directly from the SparkContext returned here. + */ + private def newSparkContext(): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val sc = new SparkContext(conf) + assert(sc.ui.isDefined) + sc + } + ignore("basic ui visibility") { - withSpark(new SparkContext("local", "test")) { sc => + withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.appUIAddress).mkString + val html = Source.fromURL(sc.ui.get.appUIAddress).mkString assert(!html.contains("random data that should not be present")) assert(html.toLowerCase.contains("stages")) assert(html.toLowerCase.contains("storage")) @@ -51,7 +64,7 @@ class UISuite extends FunSuite { } ignore("visibility at localhost:4040") { - withSpark(new SparkContext("local", "test")) { sc => + withSpark(newSparkContext()) { sc => // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString @@ -61,8 +74,8 @@ class UISuite extends FunSuite { } ignore("attaching a new tab") { - withSpark(new SparkContext("local", "test")) { sc => - val sparkUI = sc.ui + withSpark(newSparkContext()) { sc => + val sparkUI = sc.ui.get val newTab = new WebUITab(sparkUI, "foo") { attachPage(new WebUIPage("") { @@ -73,7 +86,7 @@ class UISuite extends FunSuite { } sparkUI.attachTab(newTab) eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.appUIAddress).mkString + val html = Source.fromURL(sparkUI.appUIAddress).mkString assert(!html.contains("random data that should not be present")) // check whether new page exists @@ -87,7 +100,7 @@ class UISuite extends FunSuite { } eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.appUIAddress.stripSuffix("/") + "/foo").mkString + val html = Source.fromURL(sparkUI.appUIAddress.stripSuffix("/") + "/foo").mkString // check whether new page exists assert(html.contains("magic")) } @@ -129,16 +142,20 @@ class UISuite extends FunSuite { } test("verify appUIAddress contains the scheme") { - withSpark(new SparkContext("local", "test")) { sc => - val uiAddress = sc.ui.appUIAddress - assert(uiAddress.equals("http://" + sc.ui.appUIHostPort)) + withSpark(newSparkContext()) { sc => + val ui = sc.ui.get + val uiAddress = ui.appUIAddress + val uiHostPort = ui.appUIHostPort + assert(uiAddress.equals("http://" + uiHostPort)) } } test("verify appUIAddress contains the port") { - withSpark(new SparkContext("local", "test")) { sc => - val splitUIAddress = sc.ui.appUIAddress.split(':') - assert(splitUIAddress(2).toInt == sc.ui.boundPort) + withSpark(newSparkContext()) { sc => + val ui = sc.ui.get + val splitUIAddress = ui.appUIAddress.split(':') + val boundPort = ui.boundPort + assert(splitUIAddress(2).toInt == boundPort) } } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 147ec0bc52e39..3370dd4156c3f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -34,12 +34,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val listener = new JobProgressListener(conf) def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") SparkListenerStageSubmitted(stageInfo) } def createStageEndEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") + val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") SparkListenerStageCompleted(stageInfo) } @@ -70,33 +70,37 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) - .shuffleRead === 1000) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) + assert(listener.stageIdToData.getOrElse((0, 0), fail()) + .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 1000) // finish a task with unknown executor-id, nothing should happen taskInfo = new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) taskInfo.finishTime = 1 task = new ShuffleMapTask(0) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.size === 1) // finish this task, should get updated duration taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) - .shuffleRead === 2000) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) + assert(listener.stageIdToData.getOrElse((0, 0), fail()) + .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 2000) // finish this task, should get updated duration taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0) - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) - .shuffleRead === 1000) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) + assert(listener.stageIdToData.getOrElse((0, 0), fail()) + .executorSummary.getOrElse("exe-2", fail()).shuffleRead === 1000) } test("test task success vs failure counting for different task end reasons") { @@ -119,16 +123,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc UnknownReason) var failCount = 0 for (reason <- taskFailedReasons) { - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, reason, taskInfo, metrics)) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, reason, taskInfo, metrics)) failCount += 1 - assert(listener.stageIdToData(task.stageId).numCompleteTasks === 0) - assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) + assert(listener.stageIdToData((task.stageId, 0)).numCompleteTasks === 0) + assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) } // Make sure we count success as success. - listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, metrics)) - assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1) - assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, metrics)) + assert(listener.stageIdToData((task.stageId, 1)).numCompleteTasks === 1) + assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) } test("test update metrics") { @@ -163,18 +169,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo } - listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1234L))) - listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1235L))) - listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1236L))) - listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1237L))) + listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1234L))) + listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1235L))) + listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1236L))) + listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( - (1234L, 0, makeTaskMetrics(0)), - (1235L, 0, makeTaskMetrics(100)), - (1236L, 1, makeTaskMetrics(200))))) + (1234L, 0, 0, makeTaskMetrics(0)), + (1235L, 0, 0, makeTaskMetrics(100)), + (1236L, 1, 0, makeTaskMetrics(200))))) - var stage0Data = listener.stageIdToData.get(0).get - var stage1Data = listener.stageIdToData.get(1).get + var stage0Data = listener.stageIdToData.get((0, 0)).get + var stage1Data = listener.stageIdToData.get((1, 0)).get assert(stage0Data.shuffleReadBytes == 102) assert(stage1Data.shuffleReadBytes == 201) assert(stage0Data.shuffleWriteBytes == 106) @@ -195,14 +201,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc .totalBlocksFetched == 202) // task that was included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(0, taskType, Success, makeTaskInfo(1234L, 1), + listener.onTaskEnd(SparkListenerTaskEnd(0, 0, taskType, Success, makeTaskInfo(1234L, 1), makeTaskMetrics(300))) // task that wasn't included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(1, taskType, Success, makeTaskInfo(1237L, 1), + listener.onTaskEnd(SparkListenerTaskEnd(1, 0, taskType, Success, makeTaskInfo(1237L, 1), makeTaskMetrics(400))) - stage0Data = listener.stageIdToData.get(0).get - stage1Data = listener.stageIdToData.get(1).get + stage0Data = listener.stageIdToData.get((0, 0)).get + stage1Data = listener.stageIdToData.get((1, 0)).get assert(stage0Data.shuffleReadBytes == 402) assert(stage1Data.shuffleReadBytes == 602) assert(stage0Data.shuffleWriteBytes == 406) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 6e68dcb3425aa..e1bc1379b5d80 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -34,11 +34,12 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { private val memOnly = StorageLevel.MEMORY_ONLY private val none = StorageLevel.NONE private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) + private val taskInfo1 = new TaskInfo(1, 1, 1, 1, "big", "cat", TaskLocality.ANY, false) private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly) private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly) private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk) private def rddInfo3 = new RDDInfo(3, "grace", 400, memAndDisk) - private val bm1 = BlockManagerId("big", "dog", 1, 1) + private val bm1 = BlockManagerId("big", "dog", 1) before { bus = new LiveListenerBus @@ -53,7 +54,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { assert(storageListener.rddInfoList.isEmpty) // 2 RDDs are known, but none are cached - val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0, rddInfo1), "details") + val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0, rddInfo1), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 2) assert(storageListener.rddInfoList.isEmpty) @@ -63,7 +64,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val rddInfo3Cached = rddInfo3 rddInfo2Cached.numCachedPartitions = 1 rddInfo3Cached.numCachedPartitions = 1 - val stageInfo1 = new StageInfo(1, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), "details") + val stageInfo1 = new StageInfo(1, 0, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) assert(storageListener._rddInfoMap.size === 4) assert(storageListener.rddInfoList.size === 2) @@ -71,7 +72,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { // Submitting RDDInfos with duplicate IDs does nothing val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY) rddInfo0Cached.numCachedPartitions = 1 - val stageInfo0Cached = new StageInfo(0, "0", 100, Seq(rddInfo0), "details") + val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached)) assert(storageListener._rddInfoMap.size === 4) assert(storageListener.rddInfoList.size === 2) @@ -87,7 +88,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val rddInfo1Cached = rddInfo1 rddInfo0Cached.numCachedPartitions = 1 rddInfo1Cached.numCachedPartitions = 1 - val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), "details") + val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 2) assert(storageListener.rddInfoList.size === 2) @@ -106,8 +107,8 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val myRddInfo0 = rddInfo0 val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo(0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") - bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L)) + val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") + bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 3) assert(storageListener.rddInfoList.size === 0) // not cached @@ -116,7 +117,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { assert(!storageListener._rddInfoMap(2).isCached) // Task end with no updated blocks. This should not change anything. - bus.postToAll(SparkListenerTaskEnd(0, "obliteration", Success, taskInfo, new TaskMetrics)) + bus.postToAll(SparkListenerTaskEnd(0, 0, "obliteration", Success, taskInfo, new TaskMetrics)) assert(storageListener._rddInfoMap.size === 3) assert(storageListener.rddInfoList.size === 0) @@ -128,7 +129,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { (RDDBlockId(0, 102), BlockStatus(memAndDisk, 400L, 0L, 200L)), (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L, 0L)) )) - bus.postToAll(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo, metrics1)) + bus.postToAll(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo, metrics1)) assert(storageListener._rddInfoMap(0).memSize === 800L) assert(storageListener._rddInfoMap(0).diskSize === 400L) assert(storageListener._rddInfoMap(0).tachyonSize === 200L) @@ -150,7 +151,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L, 0L)), // doesn't actually exist (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L, 0L)) // doesn't actually exist )) - bus.postToAll(SparkListenerTaskEnd(2, "obliteration", Success, taskInfo, metrics2)) + bus.postToAll(SparkListenerTaskEnd(2, 0, "obliteration", Success, taskInfo, metrics2)) assert(storageListener._rddInfoMap(0).memSize === 400L) assert(storageListener._rddInfoMap(0).diskSize === 400L) assert(storageListener._rddInfoMap(0).tachyonSize === 200L) @@ -162,4 +163,30 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) } + test("verify StorageTab contains all cached rdds") { + + val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly) + val rddInfo1 = new RDDInfo(1, "rdd1", 1 ,memOnly) + val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), "details") + val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), "details") + val taskMetrics0 = new TaskMetrics + val taskMetrics1 = new TaskMetrics + val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L, 0L)) + val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L)) + taskMetrics0.updatedBlocks = Some(Seq(block0)) + taskMetrics1.updatedBlocks = Some(Seq(block1)) + bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) + bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) + assert(storageListener.rddInfoList.size === 0) + bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0)) + assert(storageListener.rddInfoList.size === 1) + bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + assert(storageListener.rddInfoList.size === 1) + bus.postToAll(SparkListenerStageCompleted(stageInfo0)) + assert(storageListener.rddInfoList.size === 1) + bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1)) + assert(storageListener.rddInfoList.size === 2) + bus.postToAll(SparkListenerStageCompleted(stageInfo1)) + assert(storageListener.rddInfoList.size === 2) + } } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index c4765e53de17b..76bf4cfd11267 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,13 +17,16 @@ package org.apache.spark.util +import scala.concurrent.Await + import akka.actor._ + +import org.scalatest.FunSuite + import org.apache.spark._ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.scalatest.FunSuite -import scala.concurrent.Await /** * Test the AkkaUtils with various security settings. @@ -35,7 +38,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) @@ -106,13 +109,13 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + Seq((BlockManagerId("a", "hostA", 1000), size1000))) actorSystem.shutdown() slaveSystem.shutdown() @@ -157,13 +160,13 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security on and passwords match assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + Seq((BlockManagerId("a", "hostA", 1000), size1000))) actorSystem.shutdown() slaveSystem.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index 44332fc8dbc23..c3dd156b40514 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -26,13 +26,15 @@ import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec /** * Test writing files through the FileLogger. */ class FileLoggerSuite extends FunSuite with BeforeAndAfter { - private val fileSystem = Utils.getHadoopFileSystem("/") + private val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) private val allCompressionCodecs = Seq[String]( "org.apache.spark.io.LZFCompressionCodec", "org.apache.spark.io.SnappyCompressionCodec" diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 97ffb07662482..2b45d8b695853 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -21,6 +21,9 @@ import java.util.Properties import scala.collection.Map +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite @@ -35,13 +38,13 @@ class JsonProtocolSuite extends FunSuite { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) val stageCompleted = SparkListenerStageCompleted(makeStageInfo(101, 201, 301, 401L, 501L)) - val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 1, 444L, false)) + val taskStart = SparkListenerTaskStart(111, 0, makeTaskInfo(222L, 333, 1, 444L, false)) val taskGettingResult = SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 5, 3000L, true)) - val taskEnd = SparkListenerTaskEnd(1, "ShuffleMapTask", Success, + val taskEnd = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = false)) - val taskEndWithHadoopInput = SparkListenerTaskEnd(1, "ShuffleMapTask", Success, + val taskEndWithHadoopInput = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, makeTaskInfo(123L, 234, 67, 345L, false), makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true)) val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties) @@ -52,12 +55,12 @@ class JsonProtocolSuite extends FunSuite { "System Properties" -> Seq(("Username", "guest"), ("Password", "guest")), "Classpath Entries" -> Seq(("Super library", "/tmp/super_library")) )) - val blockManagerAdded = SparkListenerBlockManagerAdded( - BlockManagerId("Stars", "In your multitude...", 300, 400), 500) - val blockManagerRemoved = SparkListenerBlockManagerRemoved( - BlockManagerId("Scarce", "to be counted...", 100, 200)) + val blockManagerAdded = SparkListenerBlockManagerAdded(1L, + BlockManagerId("Stars", "In your multitude...", 300), 500) + val blockManagerRemoved = SparkListenerBlockManagerRemoved(2L, + BlockManagerId("Scarce", "to be counted...", 100)) val unpersistRdd = SparkListenerUnpersistRDD(12345) - val applicationStart = SparkListenerApplicationStart("The winner of all", 42L, "Garfield") + val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) testEvent(stageSubmitted, stageSubmittedJsonString) @@ -81,7 +84,7 @@ class JsonProtocolSuite extends FunSuite { testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics(33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false)) - testBlockManagerId(BlockManagerId("Hong", "Kong", 500, 1000)) + testBlockManagerId(BlockManagerId("Hong", "Kong", 500)) // StorageLevel testStorageLevel(StorageLevel.NONE) @@ -104,7 +107,7 @@ class JsonProtocolSuite extends FunSuite { testJobResult(jobFailed) // TaskEndReason - val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15, 16), 17, 18, 19) + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19) val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) @@ -151,6 +154,35 @@ class JsonProtocolSuite extends FunSuite { assert(newMetrics.inputMetrics.isEmpty) } + test("BlockManager events backward compatibility") { + // SparkListenerBlockManagerAdded/Removed in Spark 1.0.0 do not have a "time" property. + val blockManagerAdded = SparkListenerBlockManagerAdded(1L, + BlockManagerId("Stars", "In your multitude...", 300), 500) + val blockManagerRemoved = SparkListenerBlockManagerRemoved(2L, + BlockManagerId("Scarce", "to be counted...", 100)) + + val oldBmAdded = JsonProtocol.blockManagerAddedToJson(blockManagerAdded) + .removeField({ _._1 == "Timestamp" }) + + val deserializedBmAdded = JsonProtocol.blockManagerAddedFromJson(oldBmAdded) + assert(SparkListenerBlockManagerAdded(-1L, blockManagerAdded.blockManagerId, + blockManagerAdded.maxMem) === deserializedBmAdded) + + val oldBmRemoved = JsonProtocol.blockManagerRemovedToJson(blockManagerRemoved) + .removeField({ _._1 == "Timestamp" }) + + val deserializedBmRemoved = JsonProtocol.blockManagerRemovedFromJson(oldBmRemoved) + assert(SparkListenerBlockManagerRemoved(-1L, blockManagerRemoved.blockManagerId) === + deserializedBmRemoved) + } + + test("SparkListenerApplicationStart backwards compatibility") { + // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. + val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") + val oldEvent = JsonProtocol.applicationStartToJson(applicationStart) + .removeField({ _._1 == "App ID" }) + assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) + } /** -------------------------- * | Helper test running methods | @@ -242,8 +274,10 @@ class JsonProtocolSuite extends FunSuite { assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) => assert(e1.maxMem === e2.maxMem) + assert(e1.time === e2.time) assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) => + assert(e1.time === e2.time) assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) => assert(e1.rddId == e2.rddId) @@ -343,7 +377,6 @@ class JsonProtocolSuite extends FunSuite { assert(bm1.executorId === bm2.executorId) assert(bm1.host === bm2.host) assert(bm1.port === bm2.port) - assert(bm1.nettyPort === bm2.nettyPort) } private def assertEquals(result1: JobResult, result2: JobResult) { @@ -397,7 +430,8 @@ class JsonProtocolSuite extends FunSuite { private def assertJsonStringEquals(json1: String, json2: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - assert(formatJsonString(json1) === formatJsonString(json2)) + assert(formatJsonString(json1) === formatJsonString(json2), + s"input ${formatJsonString(json1)} got ${formatJsonString(json2)}") } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { @@ -485,7 +519,7 @@ class JsonProtocolSuite extends FunSuite { private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) } - val stageInfo = new StageInfo(a, "greetings", b, rddInfos, "details") + val stageInfo = new StageInfo(a, 0, "greetings", b, rddInfos, "details") val (acc1, acc2) = (makeAccumulableInfo(1), makeAccumulableInfo(2)) stageInfo.accumulables(acc1.id) = acc1 stageInfo.accumulables(acc2.id) = acc2 @@ -558,84 +592,246 @@ class JsonProtocolSuite extends FunSuite { private val stageSubmittedJsonString = """ - {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":100,"Stage Name": - "greetings","Number of Tasks":200,"RDD Info":[],"Details":"details", - "Accumulables":[{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - {"ID":1,"Name":"Accumulable1","Update":"delta1","Value":"val1"}]},"Properties": - {"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} + |{ + | "Event": "SparkListenerStageSubmitted", + | "Stage Info": { + | "Stage ID": 100, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 200, + | "RDD Info": [], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | }, + | "Properties": { + | "France": "Paris", + | "Germany": "Berlin", + | "Russia": "Moscow", + | "Ukraine": "Kiev" + | } + |} """ private val stageCompletedJsonString = """ - {"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":101,"Stage Name": - "greetings","Number of Tasks":201,"RDD Info":[{"RDD ID":101,"Name":"mayor","Storage - Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, - "Replication":1},"Number of Partitions":201,"Number of Cached Partitions":301, - "Memory Size":401,"Tachyon Size":0,"Disk Size":501}],"Details":"details", - "Accumulables":[{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - {"ID":1,"Name":"Accumulable1","Update":"delta1","Value":"val1"}]}} + |{ + | "Event": "SparkListenerStageCompleted", + | "Stage Info": { + | "Stage ID": 101, + | "Stage Attempt ID": 0, + | "Stage Name": "greetings", + | "Number of Tasks": 201, + | "RDD Info": [ + | { + | "RDD ID": 101, + | "Name": "mayor", + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Number of Partitions": 201, + | "Number of Cached Partitions": 301, + | "Memory Size": 401, + | "Tachyon Size": 0, + | "Disk Size": 501 + | } + | ], + | "Details": "details", + | "Accumulables": [ + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | } + | ] + | } + |} """ private val taskStartJsonString = """ - |{"Event":"SparkListenerTaskStart","Stage ID":111,"Task Info":{"Task ID":222, - |"Index":333,"Attempt":1,"Launch Time":444,"Executor ID":"executor","Host":"your kind sir", - |"Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0, - |"Failed":false,"Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - |"Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - |{"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}]}} + |{ + | "Event": "SparkListenerTaskStart", + | "Stage ID": 111, + | "Stage Attempt ID": 0, + | "Task Info": { + | "Task ID": 222, + | "Index": 333, + | "Attempt": 1, + | "Launch Time": 444, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": false, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" + | } + | ] + | } + |} """.stripMargin private val taskGettingResultJsonString = """ - |{"Event":"SparkListenerTaskGettingResult","Task Info": - | {"Task ID":1000,"Index":2000,"Attempt":5,"Launch Time":3000,"Executor ID":"executor", - | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":true,"Getting Result Time":0, - | "Finish Time":0,"Failed":false, - | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}] + |{ + | "Event": "SparkListenerTaskGettingResult", + | "Task Info": { + | "Task ID": 1000, + | "Index": 2000, + | "Attempt": 5, + | "Launch Time": 3000, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": true, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" + | } + | ] | } |} """.stripMargin private val taskEndJsonString = """ - |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask", - |"Task End Reason":{"Reason":"Success"}, - |"Task Info":{ - | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor", - | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false, - | "Getting Result Time":0,"Finish Time":0,"Failed":false, - | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}] - |}, - |"Task Metrics":{ - | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400, - | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700, - | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0, - | "Shuffle Read Metrics":{ - | "Shuffle Finish Time":900, - | "Remote Blocks Fetched":800, - | "Local Blocks Fetched":700, - | "Fetch Wait Time":900, - | "Remote Bytes Read":1000 + |{ + | "Event": "SparkListenerTaskEnd", + | "Stage ID": 1, + | "Stage Attempt ID": 0, + | "Task Type": "ShuffleMapTask", + | "Task End Reason": { + | "Reason": "Success" | }, - | "Shuffle Write Metrics":{ - | "Shuffle Bytes Written":1200, - | "Shuffle Write Time":1500 + | "Task Info": { + | "Task ID": 123, + | "Index": 234, + | "Attempt": 67, + | "Launch Time": 345, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": false, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" + | } + | ] | }, - | "Updated Blocks":[ - | {"Block ID":"rdd_0_0", - | "Status":{ - | "Storage Level":{ - | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false, - | "Replication":2 - | }, - | "Memory Size":0,"Tachyon Size":0,"Disk Size":0 + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Shuffle Read Metrics": { + | "Shuffle Finish Time": 900, + | "Remote Blocks Fetched": 800, + | "Local Blocks Fetched": 700, + | "Fetch Wait Time": 900, + | "Remote Bytes Read": 1000 + | }, + | "Shuffle Write Metrics": { + | "Shuffle Bytes Written": 1200, + | "Shuffle Write Time": 1500 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "Tachyon Size": 0, + | "Disk Size": 0 + | } | } - | } | ] | } |} @@ -643,80 +839,187 @@ class JsonProtocolSuite extends FunSuite { private val taskEndWithHadoopInputJsonString = """ - |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask", - |"Task End Reason":{"Reason":"Success"}, - |"Task Info":{ - | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor", - | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false, - | "Getting Result Time":0,"Finish Time":0,"Failed":false, - | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1", - | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"}, - | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}] - |}, - |"Task Metrics":{ - | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400, - | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700, - | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0, - | "Shuffle Write Metrics":{"Shuffle Bytes Written":1200,"Shuffle Write Time":1500}, - | "Input Metrics":{"Data Read Method":"Hadoop","Bytes Read":2100}, - | "Updated Blocks":[ - | {"Block ID":"rdd_0_0", - | "Status":{ - | "Storage Level":{ - | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false, - | "Replication":2 - | }, - | "Memory Size":0,"Tachyon Size":0,"Disk Size":0 + |{ + | "Event": "SparkListenerTaskEnd", + | "Stage ID": 1, + | "Stage Attempt ID": 0, + | "Task Type": "ShuffleMapTask", + | "Task End Reason": { + | "Reason": "Success" + | }, + | "Task Info": { + | "Task ID": 123, + | "Index": 234, + | "Attempt": 67, + | "Launch Time": 345, + | "Executor ID": "executor", + | "Host": "your kind sir", + | "Locality": "NODE_LOCAL", + | "Speculative": false, + | "Getting Result Time": 0, + | "Finish Time": 0, + | "Failed": false, + | "Accumulables": [ + | { + | "ID": 1, + | "Name": "Accumulable1", + | "Update": "delta1", + | "Value": "val1" + | }, + | { + | "ID": 2, + | "Name": "Accumulable2", + | "Update": "delta2", + | "Value": "val2" + | }, + | { + | "ID": 3, + | "Name": "Accumulable3", + | "Update": "delta3", + | "Value": "val3" + | } + | ] + | }, + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Shuffle Write Metrics": { + | "Shuffle Bytes Written": 1200, + | "Shuffle Write Time": 1500 + | }, + | "Input Metrics": { + | "Data Read Method": "Hadoop", + | "Bytes Read": 2100 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use Tachyon": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "Tachyon Size": 0, + | "Disk Size": 0 + | } | } - | } - | ]} + | ] + | } |} """ private val jobStartJsonString = """ - {"Event":"SparkListenerJobStart","Job ID":10,"Stage IDs":[1,2,3,4],"Properties": - {"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} + |{ + | "Event": "SparkListenerJobStart", + | "Job ID": 10, + | "Stage IDs": [ + | 1, + | 2, + | 3, + | 4 + | ], + | "Properties": { + | "France": "Paris", + | "Germany": "Berlin", + | "Russia": "Moscow", + | "Ukraine": "Kiev" + | } + |} """ private val jobEndJsonString = """ - {"Event":"SparkListenerJobEnd","Job ID":20,"Job Result":{"Result":"JobSucceeded"}} + |{ + | "Event": "SparkListenerJobEnd", + | "Job ID": 20, + | "Job Result": { + | "Result": "JobSucceeded" + | } + |} """ private val environmentUpdateJsonString = """ - {"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"GC speed":"9999 objects/s", - "Java home":"Land of coffee"},"Spark Properties":{"Job throughput":"80000 jobs/s, - regardless of job type"},"System Properties":{"Username":"guest","Password":"guest"}, - "Classpath Entries":{"Super library":"/tmp/super_library"}} + |{ + | "Event": "SparkListenerEnvironmentUpdate", + | "JVM Information": { + | "GC speed": "9999 objects/s", + | "Java home": "Land of coffee" + | }, + | "Spark Properties": { + | "Job throughput": "80000 jobs/s, regardless of job type" + | }, + | "System Properties": { + | "Username": "guest", + | "Password": "guest" + | }, + | "Classpath Entries": { + | "Super library": "/tmp/super_library" + | } + |} """ private val blockManagerAddedJsonString = """ - {"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"Stars", - "Host":"In your multitude...","Port":300,"Netty Port":400},"Maximum Memory":500} + |{ + | "Event": "SparkListenerBlockManagerAdded", + | "Block Manager ID": { + | "Executor ID": "Stars", + | "Host": "In your multitude...", + | "Port": 300 + | }, + | "Maximum Memory": 500, + | "Timestamp": 1 + |} """ private val blockManagerRemovedJsonString = """ - {"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"Scarce", - "Host":"to be counted...","Port":100,"Netty Port":200}} + |{ + | "Event": "SparkListenerBlockManagerRemoved", + | "Block Manager ID": { + | "Executor ID": "Scarce", + | "Host": "to be counted...", + | "Port": 100 + | }, + | "Timestamp": 2 + |} """ private val unpersistRDDJsonString = """ - {"Event":"SparkListenerUnpersistRDD","RDD ID":12345} + |{ + | "Event": "SparkListenerUnpersistRDD", + | "RDD ID": 12345 + |} """ private val applicationStartJsonString = """ - {"Event":"SparkListenerApplicationStart","App Name":"The winner of all","Timestamp":42, - "User":"Garfield"} + |{ + | "Event": "SparkListenerApplicationStart", + | "App Name": "The winner of all", + | "Timestamp": 42, + | "User": "Garfield" + |} """ private val applicationEndJsonString = """ - {"Event":"SparkListenerApplicationEnd","Timestamp":42} + |{ + | "Event": "SparkListenerApplicationEnd", + | "Timestamp": 42 + |} """ } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 04d7338488628..511d76c9144cc 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -23,37 +23,43 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.SparkContext._ +import org.apache.spark.io.CompressionCodec class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { + private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS + private def createCombiner[T](i: T) = ArrayBuffer[T](i) + private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i + private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] = + buf1 ++= buf2 - private def createCombiner(i: Int) = ArrayBuffer[Int](i) - private def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i - private def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]]( + createCombiner[T], mergeValue[T], mergeCombiners[T]) - private def createSparkConf(loadDefaults: Boolean): SparkConf = { + private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = { val conf = new SparkConf(loadDefaults) // Make the Java serializer write a reset instruction (TC_RESET) after each object to test // for a bug we had with bytes written past the last object in a batch (SPARK-2792) conf.set("spark.serializer.objectStreamReset", "1") conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + conf.set("spark.shuffle.spill.compress", codec.isDefined.toString) + conf.set("spark.shuffle.compress", codec.isDefined.toString) + codec.foreach { c => conf.set("spark.io.compression.codec", c) } // Ensure that we actually have multiple batches per spill file conf.set("spark.shuffle.spill.batchSize", "10") conf } test("simple insert") { - val conf = createSparkConf(false) + val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) - - val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, - mergeValue, mergeCombiners) + val map = createExternalMap[Int] // Single insert map.insert(1, 10) var it = map.iterator assert(it.hasNext) val kv = it.next() - assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10)) + assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10)) assert(!it.hasNext) // Multiple insert @@ -61,18 +67,17 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { map.insert(3, 30) it = map.iterator assert(it.hasNext) - assert(it.toSet == Set[(Int, ArrayBuffer[Int])]( + assert(it.toSet === Set[(Int, ArrayBuffer[Int])]( (1, ArrayBuffer[Int](10)), (2, ArrayBuffer[Int](20)), (3, ArrayBuffer[Int](30)))) + sc.stop() } test("insert with collision") { - val conf = createSparkConf(false) + val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) - - val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, - mergeValue, mergeCombiners) + val map = createExternalMap[Int] map.insertAll(Seq( (1, 10), @@ -84,30 +89,28 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val it = map.iterator assert(it.hasNext) val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) - assert(result == Set[(Int, Set[Int])]( + assert(result === Set[(Int, Set[Int])]( (1, Set[Int](10, 100, 1000)), (2, Set[Int](20, 200)), (3, Set[Int](30)))) + sc.stop() } test("ordering") { - val conf = createSparkConf(false) + val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) - val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, - mergeValue, mergeCombiners) + val map1 = createExternalMap[Int] map1.insert(1, 10) map1.insert(2, 20) map1.insert(3, 30) - val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, - mergeValue, mergeCombiners) + val map2 = createExternalMap[Int] map2.insert(2, 20) map2.insert(3, 30) map2.insert(1, 10) - val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, - mergeValue, mergeCombiners) + val map3 = createExternalMap[Int] map3.insert(3, 30) map3.insert(1, 10) map3.insert(2, 20) @@ -119,33 +122,33 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { var kv1 = it1.next() var kv2 = it2.next() var kv3 = it3.next() - assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) - assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + assert(kv1._1 === kv2._1 && kv2._1 === kv3._1) + assert(kv1._2 === kv2._2 && kv2._2 === kv3._2) kv1 = it1.next() kv2 = it2.next() kv3 = it3.next() - assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) - assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + assert(kv1._1 === kv2._1 && kv2._1 === kv3._1) + assert(kv1._2 === kv2._2 && kv2._2 === kv3._2) kv1 = it1.next() kv2 = it2.next() kv3 = it3.next() - assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) - assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + assert(kv1._1 === kv2._1 && kv2._1 === kv3._1) + assert(kv1._2 === kv2._2 && kv2._2 === kv3._2) + sc.stop() } test("null keys and values") { - val conf = createSparkConf(false) + val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) - val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, - mergeValue, mergeCombiners) + val map = createExternalMap[Int] map.insert(1, 5) map.insert(2, 6) map.insert(3, 7) assert(map.size === 3) - assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + assert(map.iterator.toSet === Set[(Int, Seq[Int])]( (1, Seq[Int](5)), (2, Seq[Int](6)), (3, Seq[Int](7)) @@ -155,7 +158,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val nullInt = null.asInstanceOf[Int] map.insert(nullInt, 8) assert(map.size === 4) - assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + assert(map.iterator.toSet === Set[(Int, Seq[Int])]( (1, Seq[Int](5)), (2, Seq[Int](6)), (3, Seq[Int](7)), @@ -167,32 +170,34 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { map.insert(nullInt, nullInt) assert(map.size === 5) val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) - assert(result == Set[(Int, Set[Int])]( + assert(result === Set[(Int, Set[Int])]( (1, Set[Int](5)), (2, Set[Int](6)), (3, Set[Int](7)), (4, Set[Int](nullInt)), (nullInt, Set[Int](nullInt, 8)) )) + sc.stop() } test("simple aggregator") { - val conf = createSparkConf(false) + val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) // reduceByKey val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) val result1 = rdd.reduceByKey(_+_).collect() - assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5))) + assert(result1.toSet === Set[(Int, Int)]((0, 5), (1, 5))) // groupByKey val result2 = rdd.groupByKey().collect().map(x => (x._1, x._2.toList)).toSet - assert(result2.toSet == Set[(Int, Seq[Int])] + assert(result2.toSet === Set[(Int, Seq[Int])] ((0, List[Int](1, 1, 1, 1, 1)), (1, List[Int](1, 1, 1, 1, 1)))) + sc.stop() } test("simple cogroup") { - val conf = createSparkConf(false) + val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) @@ -200,77 +205,98 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { result.foreach { case (i, (seq1, seq2)) => i match { - case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4)) - case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3)) - case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]()) - case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]()) - case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]()) + case 0 => assert(seq1.toSet === Set[Int]() && seq2.toSet === Set[Int](2, 4)) + case 1 => assert(seq1.toSet === Set[Int](1) && seq2.toSet === Set[Int](1, 3)) + case 2 => assert(seq1.toSet === Set[Int](2) && seq2.toSet === Set[Int]()) + case 3 => assert(seq1.toSet === Set[Int](3) && seq2.toSet === Set[Int]()) + case 4 => assert(seq1.toSet === Set[Int](4) && seq2.toSet === Set[Int]()) } } + sc.stop() } test("spilling") { - val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + testSimpleSpilling() + } + + test("spilling with compression") { + // Keep track of which compression codec we're using to report in test failure messages + var lastCompressionCodec: Option[String] = None + try { + allCompressionCodecs.foreach { c => + lastCompressionCodec = Some(c) + testSimpleSpilling(Some(c)) + } + } catch { + // Include compression codec used in test failure message + // We need to catch Throwable here because assertion failures are not covered by Exceptions + case t: Throwable => + val compressionMessage = lastCompressionCodec + .map { c => "with compression using codec " + c } + .getOrElse("without compression") + val newException = new Exception(s"Test failed $compressionMessage:\n\n${t.getMessage}") + newException.setStackTrace(t.getStackTrace) + throw newException + } + } + + /** + * Test spilling through simple aggregations and cogroups. + * If a compression codec is provided, use it. Otherwise, do not compress spills. + */ + private def testSimpleSpilling(codec: Option[String] = None): Unit = { + val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) val resultA = rddA.reduceByKey(math.max).collect() - assert(resultA.length == 50000) - resultA.foreach { case(k, v) => - if (v != k * 2 + 1) { - fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") - } + assert(resultA.length === 50000) + resultA.foreach { case (k, v) => + assert(v === k * 2 + 1, s"Value for $k was wrong: expected ${k * 2 + 1}, got $v") } // groupByKey - should spill ~17 times val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) val resultB = rddB.groupByKey().collect() - assert(resultB.length == 25000) - resultB.foreach { case(i, seq) => + assert(resultB.length === 25000) + resultB.foreach { case (i, seq) => val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) - if (seq.toSet != expected) { - fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") - } + assert(seq.toSet === expected, + s"Value for $i was wrong: expected $expected, got ${seq.toSet}") } // cogroup - should spill ~7 times val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) val resultC = rddC1.cogroup(rddC2).collect() - assert(resultC.length == 10000) - resultC.foreach { case(i, (seq1, seq2)) => + assert(resultC.length === 10000) + resultC.foreach { case (i, (seq1, seq2)) => i match { case 0 => - assert(seq1.toSet == Set[Int](0)) - assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + assert(seq1.toSet === Set[Int](0)) + assert(seq2.toSet === Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) case 1 => - assert(seq1.toSet == Set[Int](1)) - assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) + assert(seq1.toSet === Set[Int](1)) + assert(seq2.toSet === Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) case 5000 => - assert(seq1.toSet == Set[Int](5000)) - assert(seq2.toSet == Set[Int]()) + assert(seq1.toSet === Set[Int](5000)) + assert(seq2.toSet === Set[Int]()) case 9999 => - assert(seq1.toSet == Set[Int](9999)) - assert(seq2.toSet == Set[Int]()) + assert(seq1.toSet === Set[Int](9999)) + assert(seq2.toSet === Set[Int]()) case _ => } } + sc.stop() } test("spilling with hash collisions") { - val conf = createSparkConf(true) + val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) = - buffer1 ++= buffer2 - - val map = new ExternalAppendOnlyMap[String, String, ArrayBuffer[String]]( - createCombiner, mergeValue, mergeCombiners) + val map = createExternalMap[String] val collisionPairs = Seq( ("Aa", "BB"), // 2112 @@ -312,13 +338,13 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { count += 1 } assert(count === 100000 + collisionPairs.size * 2) + sc.stop() } test("spilling with many hash collisions") { - val conf = createSparkConf(true) + val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes @@ -337,15 +363,14 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { count += 1 } assert(count === 10000) + sc.stop() } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = createSparkConf(true) + val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - - val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]]( - createCombiner, mergeValue, mergeCombiners) + val map = createExternalMap[Int] (1 to 100000).foreach { i => map.insert(i, i) } map.insert(Int.MaxValue, Int.MaxValue) @@ -355,15 +380,14 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { // Should not throw NoSuchElementException it.next() } + sc.stop() } test("spilling with null keys and values") { - val conf = createSparkConf(true) + val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - - val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]]( - createCombiner, mergeValue, mergeCombiners) + val map = createExternalMap[Int] map.insertAll((1 to 100000).iterator.map(i => (i, i))) map.insert(null.asInstanceOf[Int], 1) @@ -375,6 +399,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { // Should not throw NullPointerException it.next() } + sc.stop() } } diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala new file mode 100644 index 0000000000000..f855831b8e367 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import scala.util.Random + +import org.scalatest.FunSuite + + +class ByteArrayChunkOutputStreamSuite extends FunSuite { + + test("empty output") { + val o = new ByteArrayChunkOutputStream(1024) + assert(o.toArrays.length === 0) + } + + test("write a single byte") { + val o = new ByteArrayChunkOutputStream(1024) + o.write(10) + assert(o.toArrays.length === 1) + assert(o.toArrays.head.toSeq === Seq(10.toByte)) + } + + test("write a single near boundary") { + val o = new ByteArrayChunkOutputStream(10) + o.write(new Array[Byte](9)) + o.write(99) + assert(o.toArrays.length === 1) + assert(o.toArrays.head(9) === 99.toByte) + } + + test("write a single at boundary") { + val o = new ByteArrayChunkOutputStream(10) + o.write(new Array[Byte](10)) + o.write(99) + assert(o.toArrays.length === 2) + assert(o.toArrays(1).length === 1) + assert(o.toArrays(1)(0) === 99.toByte) + } + + test("single chunk output") { + val ref = new Array[Byte](8) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("single chunk output at boundary size") { + val ref = new Array[Byte](10) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("multiple chunk output") { + val ref = new Array[Byte](26) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 6) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 26)) + } + + test("multiple chunk output at boundary size") { + val ref = new Array[Byte](30) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 10) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 30)) + } +} diff --git a/dev/check-license b/dev/check-license index 625ec161bc571..9ff0929e9a5e8 100755 --- a/dev/check-license +++ b/dev/check-license @@ -23,18 +23,18 @@ acquire_rat_jar () { URL1="http://search.maven.org/remotecontent?filepath=org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" URL2="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" - JAR=$rat_jar + JAR="$rat_jar" if [[ ! -f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet if [ ! -f "$JAR" ]; then # Download printf "Attempting to fetch rat\n" - JAR_DL=${JAR}.part + JAR_DL="${JAR}.part" if hash curl 2>/dev/null; then - (curl --progress-bar ${URL1} > "$JAR_DL" || curl --progress-bar ${URL2} > "$JAR_DL") && mv "$JAR_DL" "$JAR" + (curl --silent "${URL1}" > "$JAR_DL" || curl --silent "${URL2}" > "$JAR_DL") && mv "$JAR_DL" "$JAR" elif hash wget 2>/dev/null; then - (wget --progress=bar ${URL1} -O "$JAR_DL" || wget --progress=bar ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR" + (wget --quiet ${URL1} -O "$JAR_DL" || wget --quiet ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR" else printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 @@ -50,7 +50,7 @@ acquire_rat_jar () { } # Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" if test -x "$JAVA_HOME/bin/java"; then @@ -60,17 +60,17 @@ else fi export RAT_VERSION=0.10 -export rat_jar=$FWDIR/lib/apache-rat-${RAT_VERSION}.jar -mkdir -p $FWDIR/lib +export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar +mkdir -p "$FWDIR"/lib [[ -f "$rat_jar" ]] || acquire_rat_jar || { echo "Download failed. Obtain the rat jar manually and place it at $rat_jar" exit 1 } -$java_cmd -jar $rat_jar -E $FWDIR/.rat-excludes -d $FWDIR > rat-results.txt +$java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt -ERRORS=$(cat rat-results.txt | grep -e "??") +ERRORS="$(cat rat-results.txt | grep -e "??")" if test ! -z "$ERRORS"; then echo "Could not find Apache license headers in the following files:" diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 28f26d2368254..281e8d4de6d71 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -60,14 +60,14 @@ if [[ ! "$@" =~ --package-only ]]; then -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ --batch-mode release:prepare mvn -DskipTests \ -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ release:perform cd .. @@ -117,12 +117,13 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2" \ - "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & -make_binary_release "hadoop2-without-hive" \ - "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & +make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & +make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" & +make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" & +make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & +make_binary_release "mapr3" "-Pmapr3 -Phive" & +make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" & wait # Copy data diff --git a/dev/create-release/generate-changelist.py b/dev/create-release/generate-changelist.py index de1b5d4ae1314..2e1a35a629342 100755 --- a/dev/create-release/generate-changelist.py +++ b/dev/create-release/generate-changelist.py @@ -125,8 +125,8 @@ def cleanup(ask=True): pr_num = [line.split()[1].lstrip("#") for line in body_lines if "Closes #" in line][0] github_url = "github.com/apache/spark/pull/%s" % pr_num day = time.strptime(date.split()[0], "%Y-%m-%d") - if day < SPARK_REPO_CHANGE_DATE1 or - (day < SPARK_REPO_CHANGE_DATE2 and pr_num < SPARK_REPO_PR_NUM_THRESH): + if (day < SPARK_REPO_CHANGE_DATE1 or + (day < SPARK_REPO_CHANGE_DATE2 and pr_num < SPARK_REPO_PR_NUM_THRESH)): github_url = "github.com/apache/incubator-spark/pull/%s" % pr_num append_to_changelist(" %s" % subject) diff --git a/dev/lint-python b/dev/lint-python index 4efddad839387..772f856154ae0 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -18,10 +18,10 @@ # SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" +SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" -cd $SPARK_ROOT_DIR +cd "$SPARK_ROOT_DIR" # Get pep8 at runtime so that we don't rely on it being installed on the build server. #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 @@ -30,6 +30,7 @@ cd $SPARK_ROOT_DIR #+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?)) PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py" +PEP8_PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" curl_status=$? @@ -44,7 +45,7 @@ fi #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python $PEP8_SCRIPT_PATH ./python > "$PEP8_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" $PEP8_PATHS_TO_CHECK > "$PEP8_REPORT_PATH" pep8_status=${PIPESTATUS[0]} #$? if [ $pep8_status -ne 0 ]; then @@ -54,7 +55,7 @@ else echo "PEP 8 checks passed." fi -rm -f "$PEP8_REPORT_PATH" +rm "$PEP8_REPORT_PATH" rm "$PEP8_SCRIPT_PATH" exit $pep8_status diff --git a/dev/mima b/dev/mima index 09e4482af5f3d..f9b9b03538f15 100755 --- a/dev/mima +++ b/dev/mima @@ -21,12 +21,12 @@ set -o pipefail set -e # Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update -export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"` +export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore diff --git a/dev/run-tests b/dev/run-tests index 132f696d6447a..79401213a7fa2 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -18,7 +18,7 @@ # # Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then @@ -55,7 +55,7 @@ JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..* # Partial solution for SPARK-1455. Only run Hive tests if there are sql changes. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master - diffs=`git diff --name-only master | grep "^sql/"` + diffs=`git diff --name-only master | grep "^\(sql/\)\|\(bin/spark-sql\)\|\(sbin/start-thriftserver.sh\)"` if [ -n "$diffs" ]; then echo "Detected changes in SQL. Will run Hive test suite." _RUN_SQL_TESTS=true @@ -89,17 +89,17 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" -# Build Spark; we always build with Hive because the PySpark SparkSQL tests need it. +# Build Spark; we always build with Hive because the PySpark Spark SQL tests need it. # echo "q" is needed because sbt on encountering a build file with failure # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. -BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver " +BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive " echo -e "q\n" | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled: if [ -n "$_RUN_SQL_TESTS" ]; then - SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" fi # echo "q" is needed because sbt on encountering a build file with failure # (either resolution or compilation) prompts the user for input either q, r, diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 31506e28e05af..06c3781eb3ccf 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -33,9 +33,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" -# NOTE: Jenkins will kill the whole build after 120 minutes. -# Tests are a large part of that, but not all of it. -TESTS_TIMEOUT="120m" +TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout function post_message () { local message=$1 @@ -93,9 +91,14 @@ function post_message () { else merge_note=" * This patch merges cleanly." - non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") + source_files=$( + git diff master --name-only \ + | grep -v -e "\/test" `# ignore files in test directories` \ + | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ + | tr "\n" " " + ) new_public_classes=$( - git diff master ${non_test_files} `# diff this patch against master and...` \ + git diff master ${source_files} `# diff this patch against master and...` \ | grep "^\+" `# filter in only added lines` \ | sed -r -e "s/^\+//g" `# remove the leading +` \ | grep -e "trait " -e "class " `# filter in lines with these key words` \ @@ -138,7 +141,8 @@ function post_message () { test_result="$?" if [ "$test_result" -eq "124" ]; then - fail_message="**Tests timed out** after a configured wait of \`${TESTS_TIMEOUT}\`." + fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** after \ + a configured wait of \`${TESTS_TIMEOUT}\`." post_message "$fail_message" exit $test_result else diff --git a/dev/scalastyle b/dev/scalastyle index b53053a04ff42..efb5f291ea3b7 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,9 +17,9 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt # Check style with YARN alpha built too -echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ +echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt # Check style with YARN built too echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \ diff --git a/docs/README.md b/docs/README.md index fd7ba4e0d72ea..0a0126c5747d1 100644 --- a/docs/README.md +++ b/docs/README.md @@ -30,7 +30,7 @@ called `_site` containing index.html as well as the rest of the compiled files. You can modify the default Jekyll build as follows: # Skip generating API docs (which takes a while) - $ SKIP_SCALADOC=1 jekyll build + $ SKIP_API=1 jekyll build # Serve content locally on port 4000 $ jekyll serve --watch # Build the site with extra features used on the live page diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 2dbbbf6feb4b8..3b02e090aec28 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -25,8 +25,8 @@ curr_dir = pwd cd("..") - puts "Running 'sbt/sbt compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `sbt/sbt compile unidoc` + puts "Running 'sbt/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `sbt/sbt -Pkinesis-asl compile unidoc` puts "Moving back into docs dir." cd("docs") diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 4d87ab92cec5b..bce7412c7d4c9 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -96,13 +96,12 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} -# Building Thrift JDBC server and CLI for Spark SQL - -Spark SQL supports Thrift JDBC server and CLI. -See sql-programming-guide.md for more information about those features. -You can use those features by setting `-Phive-thriftserver` when building Spark as follows. +# Building With Hive and JDBC Support +To enable Hive integration for Spark SQL along with its JDBC server and CLI, +add the `-Phive` profile to your existing build options. {% highlight bash %} -mvn -Phive-thriftserver assembly +# Apache Hadoop 2.4.X with Hive support +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package {% endhighlight %} # Spark Tests in Maven diff --git a/docs/configuration.md b/docs/configuration.md index 981170d8b49b7..36178efb97103 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -214,6 +214,16 @@ Apart from these, the following properties are also available, and may be useful process. The user can specify multiple of these and to set multiple environment variables. + + spark.mesos.executor.home + driver side SPARK_HOME + + Set the directory in which Spark is installed on the executors in Mesos. By default, the + executors will simply use the driver's Spark home directory, which may not be visible to + them. Note that this is only relevant if a Spark binary package is not specified through + spark.executor.uri. + + #### Shuffle Behavior @@ -283,12 +293,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.manager - HASH + sort - Implementation to use for shuffling data. A hash-based shuffle manager is the default, but - starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more - memory-efficient in environments with small executors, such as YARN. To use that, change - this value to SORT. + Implementation to use for shuffling data. There are two implementations available: + sort and hash. Sort-based shuffle is more memory-efficient and is + the default option starting in 1.2. diff --git a/docs/img/streaming-arch.png b/docs/img/streaming-arch.png index bc57b460fdf8b..ac35f1d34cf3d 100644 Binary files a/docs/img/streaming-arch.png and b/docs/img/streaming-arch.png differ diff --git a/docs/img/streaming-figures.pptx b/docs/img/streaming-figures.pptx index 1b18c2ee0ea3e..d1cc25e379f46 100644 Binary files a/docs/img/streaming-figures.pptx and b/docs/img/streaming-figures.pptx differ diff --git a/docs/img/streaming-kinesis-arch.png b/docs/img/streaming-kinesis-arch.png new file mode 100644 index 0000000000000..bea5fa88df985 Binary files /dev/null and b/docs/img/streaming-kinesis-arch.png differ diff --git a/docs/index.md b/docs/index.md index 4ac0982ae54f1..7fe6b43d32af7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -103,6 +103,8 @@ options for deployment: * [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * [3rd Party Hadoop Distributions](hadoop-third-party-distributions.html): using common Hadoop distributions +* Integration with other storage systems: + * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system * [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index ab10b2f01f87b..d5c539db791be 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -43,6 +43,17 @@ level of confidence in observed user preferences, rather than explicit ratings g model then tries to find latent factors that can be used to predict the expected preference of a user for an item. +### Scaling of the regularization parameter + +Since v1.1, we scale the regularization parameter `lambda` in solving each least squares problem by +the number of ratings the user generated in updating user factors, +or the number of ratings the product received in updating product factors. +This approach is named "ALS-WR" and discussed in the paper +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +It makes `lambda` less dependent on the scale of the dataset. +So we can apply the best parameter learned from a sampled subset to the full dataset +and expect similar performance. + ## Examples
diff --git a/docs/mllib-basics.md b/docs/mllib-data-types.md similarity index 99% rename from docs/mllib-basics.md rename to docs/mllib-data-types.md index 8752df412950a..101dc2f8695f3 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-data-types.md @@ -1,7 +1,7 @@ --- layout: global -title: Basics - MLlib -displayTitle: MLlib - Basics +title: Data Types - MLlib +displayTitle: MLlib - Data Types --- * Table of contents diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index c01a92a9a1b26..12a6afbeea829 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -7,20 +7,26 @@ displayTitle: MLlib - Decision Tree * Table of contents {:toc} -Decision trees and their ensembles are popular methods for the machine learning tasks of +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of classification and regression. Decision trees are widely used since they are easy to interpret, -handle categorical variables, extend to the multiclass classification setting, do not require +handle categorical features, extend to the multiclass classification setting, do not require feature scaling and are able to capture nonlinearities and feature interactions. Tree ensemble -algorithms such as decision forest and boosting are among the top performers for classification and +algorithms such as random forests and boosting are among the top performers for classification and regression tasks. +MLlib supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions of instances. + ## Basic algorithm The decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature -space by choosing a single element from the *best split set* where each element of the set maximizes -the information gain at a tree node. In other words, the split chosen at each tree node is chosen -from the set `$\underset{s}{\operatorname{argmax}} IG(D,s)$` where `$IG(D,s)$` is the information -gain when a split `$s$` is applied to a dataset `$D$`. +space. The tree predicts the same label for each bottommost (leaf) partition. +Each partition is chosen greedily by selecting the *best split* from a set of possible splits, +in order to maximize the information gain at a tree node. In other words, the split chosen at each +tree node is chosen from the set `$\underset{s}{\operatorname{argmax}} IG(D,s)$` where `$IG(D,s)$` +is the information gain when a split `$s$` is applied to a dataset `$D$`. ### Node impurity and information gain @@ -52,9 +58,10 @@ impurity measure for regression (variance). -The *information gain* is the difference in the parent node impurity and the weighted sum of the two -child node impurities. Assuming that a split $s$ partitions the dataset `$D$` of size `$N$` into two -datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, respectively: +The *information gain* is the difference between the parent node impurity and the weighted sum of +the two child node impurities. Assuming that a split $s$ partitions the dataset `$D$` of size `$N$` +into two datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, +respectively, the information gain is: `$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$` @@ -62,124 +69,331 @@ datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, **Continuous features** -For small datasets in single machine implementations, the split candidates for each continuous +For small datasets in single-machine implementations, the split candidates for each continuous feature are typically the unique values for the feature. Some implementations sort the feature values and then use the ordered unique values as split candidates for faster tree calculations. -Finding ordered unique feature values is computationally intensive for large distributed -datasets. One can get an approximate set of split candidates by performing a quantile calculation -over a sampled fraction of the data. The ordered splits create "bins" and the maximum number of such -bins can be specified using the `maxBins` parameters. +Sorting feature values is expensive for large distributed datasets. +This implementation computes an approximate set of split candidates by performing a quantile +calculation over a sampled fraction of the data. +The ordered splits create "bins" and the maximum number of such +bins can be specified using the `maxBins` parameter. Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario -since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of +since the default `maxBins` value is 32). The tree algorithm automatically reduces the number of bins if the condition is not satisfied. **Categorical features** -For `$M$` categorical feature values, one could come up with `$2^(M-1)-1$` split candidates. For -binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the -categorical feature values by the proportion of labels falling in one of the two classes (see -Section 9.2.4 in +For a categorical feature with `$M$` possible values (categories), one could come up with +`$2^{M-1}-1$` split candidates. For binary (0/1) classification and regression, +we can reduce the number of split candidates to `$M-1$` by ordering the +categorical feature values by the average label. (See Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for -details). For example, for a binary classification problem with one categorical feature with three -categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical -features are ordered as A followed by C followed B or A, C, B. The two split candidates are A \| C, B -and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification -when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value -is used for ordering. +details.) For example, for a binary classification problem with one categorical feature with three +categories A, B and C whose corresponding proportions of label 1 are 0.2, 0.6 and 0.4, the categorical +features are ordered as A, C, B. The two split candidates are A \| C, B +and A , C \| B where \| denotes the split. + +In multiclass classification, all `$2^{M-1}-1$` possible splits are used whenever possible. +When `$2^{M-1}-1$` is greater than the `maxBins` parameter, we use a (heuristic) method +similar to the method used for binary classification and regression. +The `$M$` categorical feature values are ordered by impurity, +and the resulting `$M-1$` split candidates are considered. ### Stopping rule The recursive tree construction is stopped at a node when one of the two conditions is met: -1. The node depth is equal to the `maxDepth` training parameter +1. The node depth is equal to the `maxDepth` training parameter. 2. No split candidate leads to an information gain at the node. +## Implementation details + ### Max memory requirements -For faster processing, the decision tree algorithm performs simultaneous histogram computations for all nodes at each level of the tree. This could lead to high memory requirements at deeper levels of the tree leading to memory overflow errors. To alleviate this problem, a 'maxMemoryInMB' training parameter is provided which specifies the maximum amount of memory at the workers (twice as much at the master) to be allocated to the histogram computation. The default value is conservatively chosen to be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements for a level-wise computation crosses the `maxMemoryInMB` threshold, the node training tasks at each subsequent level is split into smaller tasks. +For faster processing, the decision tree algorithm performs simultaneous histogram computations for +all nodes at each level of the tree. This could lead to high memory requirements at deeper levels +of the tree, potentially leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB` +training parameter specifies the maximum amount of memory at the workers (twice as much at the +master) to be allocated to the histogram computation. The default value is conservatively chosen to +be 256 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements +for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each +subsequent level are split into smaller tasks. + +Note that, if you have a large amount of memory, increasing `maxMemoryInMB` can lead to faster +training by requiring fewer passes over the data. + +### Binning feature values + +Increasing `maxBins` allows the algorithm to consider more split candidates and make fine-grained +split decisions. However, it also increases computation and communication. + +Note that the `maxBins` parameter must be at least the maximum number of categories `$M$` for +any categorical feature. + +### Scaling -### Practical limitations +Computation scales approximately linearly in the number of training instances, +in the number of features, and in the `maxBins` parameter. +Communication scales approximately linearly in the number of features and in `maxBins`. -1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. -2. Python is not supported in this release. +The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. ## Examples ### Classification -The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then -perform classification using a decision tree using Gini impurity as an impurity measure and a +The example below demonstrates how to load a +[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), +parse it as an RDD of `LabeledPoint` and then +perform classification using a decision tree with Gini impurity as an impurity measure and a maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy.
+
{% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Gini - -// Load and parse the data file -val data = sc.textFile("data/mllib/sample_tree_data.csv") -val parsedData = data.map { line => - val parts = line.split(',').map(_.toDouble) - LabeledPoint(parts(0), Vectors.dense(parts.tail)) -} +import org.apache.spark.mllib.util.MLUtils -// Run training algorithm to build the model +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache() + +// Train a DecisionTree model. +// Empty categoricalFeaturesInfo indicates all features are continuous. +val numClasses = 2 +val categoricalFeaturesInfo = Map[Int, Int]() +val impurity = "gini" val maxDepth = 5 -val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth) +val maxBins = 32 + +val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) -// Evaluate model on training examples and compute training error -val labelAndPreds = parsedData.map { point => +// Evaluate model on training instances and compute training error +val labelAndPreds = data.map { point => val prediction = model.predict(point.features) (point.label, prediction) } -val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count +val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count println("Training Error = " + trainErr) +println("Learned classification tree model:\n" + model) +{% endhighlight %} +
+ +
+{% highlight java %} +import java.util.HashMap; +import scala.Tuple2; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; + +SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); + +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +String datapath = "data/mllib/sample_libsvm_data.txt"; +JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + +// Set parameters. +// Empty categoricalFeaturesInfo indicates all features are continuous. +Integer numClasses = 2; +HashMap categoricalFeaturesInfo = new HashMap(); +String impurity = "gini"; +Integer maxDepth = 5; +Integer maxBins = 32; + +// Train a DecisionTree model for classification. +final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + +// Evaluate model on training instances and compute training error +JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); +Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); +System.out.println("Training error: " + trainErr); +System.out.println("Learned classification tree model:\n" + model); +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + +# Load and parse the data file into an RDD of LabeledPoint. +# Cache the data since we will use it again to compute training error. +data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache() + +# Train a DecisionTree model. +# Empty categoricalFeaturesInfo indicates all features are continuous. +model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={}, + impurity='gini', maxDepth=5, maxBins=32) + +# Evaluate model on training instances and compute training error +predictions = model.predict(data.map(lambda x: x.features)) +labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions) +trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count()) +print('Training Error = ' + str(trainErr)) +print('Learned classification tree model:') +print(model) {% endhighlight %} + +Note: When making predictions for a dataset, it is more efficient to do batch prediction rather +than separately calling `predict` on each data point. This is because the Python code makes calls +to an underlying `DecisionTree` model in Scala.
+
### Regression -The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then -perform regression using a decision tree using variance as an impurity measure and a maximum tree +The example below demonstrates how to load a +[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), +parse it as an RDD of `LabeledPoint` and then +perform regression using a decision tree with variance as an impurity measure and a maximum tree depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
+
{% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Variance - -// Load and parse the data file -val data = sc.textFile("data/mllib/sample_tree_data.csv") -val parsedData = data.map { line => - val parts = line.split(',').map(_.toDouble) - LabeledPoint(parts(0), Vectors.dense(parts.tail)) -} +import org.apache.spark.mllib.util.MLUtils -// Run training algorithm to build the model +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache() + +// Train a DecisionTree model. +// Empty categoricalFeaturesInfo indicates all features are continuous. +val categoricalFeaturesInfo = Map[Int, Int]() +val impurity = "variance" val maxDepth = 5 -val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth) +val maxBins = 32 + +val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) -// Evaluate model on training examples and compute training error -val valuesAndPreds = parsedData.map { point => +// Evaluate model on training instances and compute training error +val labelsAndPredictions = data.map { point => val prediction = model.predict(point.features) (point.label, prediction) } -val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("training Mean Squared Error = " + MSE) +val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() +println("Training Mean Squared Error = " + trainMSE) +println("Learned regression tree model:\n" + model) {% endhighlight %}
+ +
+{% highlight java %} +import java.util.HashMap; +import scala.Tuple2; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; + +// Load and parse the data file. +// Cache the data since we will use it again to compute training error. +String datapath = "data/mllib/sample_libsvm_data.txt"; +JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + +SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); + +// Set parameters. +// Empty categoricalFeaturesInfo indicates all features are continuous. +HashMap categoricalFeaturesInfo = new HashMap(); +String impurity = "variance"; +Integer maxDepth = 5; +Integer maxBins = 32; + +// Train a DecisionTree model. +final DecisionTreeModel model = DecisionTree.trainRegressor(data, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + +// Evaluate model on training instances and compute training error +JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); +Double trainMSE = + predictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); +System.out.println("Training Mean Squared Error: " + trainMSE); +System.out.println("Learned regression tree model:\n" + model); +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + +# Load and parse the data file into an RDD of LabeledPoint. +# Cache the data since we will use it again to compute training error. +data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache() + +# Train a DecisionTree model. +# Empty categoricalFeaturesInfo indicates all features are continuous. +model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={}, + impurity='variance', maxDepth=5, maxBins=32) + +# Evaluate model on training instances and compute training error +predictions = model.predict(data.map(lambda x: x.features)) +labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions) +trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count()) +print('Training Mean Squared Error = ' + str(trainMSE)) +print('Learned regression tree model:') +print(model) +{% endhighlight %} + +Note: When making predictions for a dataset, it is more efficient to do batch prediction rather +than separately calling `predict` on each data point. This is because the Python code makes calls +to an underlying `DecisionTree` model in Scala. +
+
diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 065d646496131..21cb35b4270ca 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -11,7 +11,7 @@ displayTitle: MLlib - Dimensionality Reduction of reducing the number of variables under consideration. It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -MLlib provides support for dimensionality reduction on tall-and-skinny matrices. +MLlib provides support for dimensionality reduction on the RowMatrix class. ## Singular value decomposition (SVD) @@ -39,8 +39,26 @@ If we keep the top $k$ singular values, then the dimensions of the resulting low * `$\Sigma$`: `$k \times k$`, * `$V$`: `$n \times k$`. -MLlib provides SVD functionality to row-oriented matrices that have only a few columns, -say, less than $1000$, but many rows, i.e., *tall-and-skinny* matrices. +### Performance +We assume $n$ is smaller than $m$. The singular values and the right singular vectors are derived +from the eigenvalues and the eigenvectors of the Gramian matrix $A^T A$. The matrix +storing the left singular vectors $U$, is computed via matrix multiplication as +$U = A (V S^{-1})$, if requested by the user via the computeU parameter. +The actual method to use is determined automatically based on the computational cost: + +* If $n$ is small ($n < 100$) or $k$ is large compared with $n$ ($k > n / 2$), we compute the Gramian matrix +first and then compute its top eigenvalues and eigenvectors locally on the driver. +This requires a single pass with $O(n^2)$ storage on each executor and on the driver, and +$O(n^2 k)$ time on the driver. +* Otherwise, we compute $(A^T A) v$ in a distributive way and send it to +ARPACK to +compute $(A^T A)$'s top eigenvalues and eigenvectors on the driver node. This requires $O(k)$ +passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver. + +### SVD Example + +MLlib provides SVD functionality to row-oriented matrices, provided in the +RowMatrix class.
@@ -124,9 +142,8 @@ MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format.
-The following code demonstrates how to compute principal components on a tall-and-skinny `RowMatrix` +The following code demonstrates how to compute principal components on a `RowMatrix` and use them to project the vectors into a low-dimensional space. -The number of columns should be small, e.g, less than 1000. {% highlight scala %} import org.apache.spark.mllib.linalg.Matrix @@ -144,7 +161,7 @@ val projected: RowMatrix = mat.multiply(pc)
-The following code demonstrates how to compute principal components on a tall-and-skinny `RowMatrix` +The following code demonstrates how to compute principal components on a `RowMatrix` and use them to project the vectors into a low-dimensional space. The number of columns should be small, e.g, less than 1000. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 4b3cb715c58c7..44f0f76220b6e 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -1,15 +1,94 @@ --- layout: global -title: Feature Extraction - MLlib -displayTitle: MLlib - Feature Extraction +title: Feature Extraction and Transformation - MLlib +displayTitle: MLlib - Feature Extraction and Transformation --- * Table of contents {:toc} + +## TF-IDF + +[Term frequency-inverse document frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a feature +vectorization method widely used in text mining to reflect the importance of a term to a document in the corpus. +Denote a term by `$t$`, a document by `$d$`, and the corpus by `$D$`. +Term frequency `$TF(t, d)$` is the number of times that term `$t$` appears in document `$d$`, +while document frequency `$DF(t, D)$` is the number of documents that contains term `$t$`. +If we only use term frequency to measure the importance, it is very easy to over-emphasize terms that +appear very often but carry little information about the document, e.g., "a", "the", and "of". +If a term appears very often across the corpus, it means it doesn't carry special information about +a particular document. +Inverse document frequency is a numerical measure of how much information a term provides: +`\[ +IDF(t, D) = \log \frac{|D| + 1}{DF(t, D) + 1}, +\]` +where `$|D|$` is the total number of documents in the corpus. +Since logarithm is used, if a term appears in all documents, its IDF value becomes 0. +Note that a smoothing term is applied to avoid dividing by zero for terms outside the corpus. +The TF-IDF measure is simply the product of TF and IDF: +`\[ +TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). +\]` +There are several variants on the definition of term frequency and document frequency. +In MLlib, we separate TF and IDF to make them flexible. + +Our implementation of term frequency utilizes the +[hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). +A raw feature is mapped into an index (term) by applying a hash function. +Then term frequencies are calculated based on the mapped indices. +This approach avoids the need to compute a global term-to-index map, +which can be expensive for a large corpus, but it suffers from potential hash collisions, +where different raw features may become the same term after hashing. +To reduce the chance of collision, we can increase the target feature dimension, i.e., +the number of buckets of the hash table. +The default feature dimension is `$2^{20} = 1,048,576$`. + +**Note:** MLlib doesn't provide tools for text segmentation. +We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and +[scalanlp/chalk](https://github.com/scalanlp/chalk). + +
+
+ +TF and IDF are implemented in [HashingTF](api/scala/index.html#org.apache.spark.mllib.feature.HashingTF) +and [IDF](api/scala/index.html#org.apache.spark.mllib.feature.IDF). +`HashingTF` takes an `RDD[Iterable[_]]` as the input. +Each record could be an iterable of strings or other types. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.mllib.feature.HashingTF +import org.apache.spark.mllib.linalg.Vector + +val sc: SparkContext = ... + +// Load documents (one per line). +val documents: RDD[Seq[String]] = sc.textFile("...").map(_.split(" ").toSeq) + +val hashingTF = new HashingTF() +val tf: RDD[Vector] = hasingTF.transform(documents) +{% endhighlight %} + +While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes: +first to compute the IDF vector and second to scale the term frequencies by IDF. + +{% highlight scala %} +import org.apache.spark.mllib.feature.IDF + +// ... continue from the previous example +tf.cache() +val idf = new IDF().fit(tf) +val tfidf: RDD[Vector] = idf.transform(tf) +{% endhighlight %} +
+
+ ## Word2Vec -Word2Vec computes distributed vector representation of words. The main advantage of the distributed +[Word2Vec](https://code.google.com/p/word2vec/) computes distributed vector representation of words. +The main advantage of the distributed representations is that similar words are close in the vector space, which makes generalization to novel patterns easier and model estimation more robust. Distributed vector representation is showed to be useful in many natural language processing applications such as named entity @@ -70,4 +149,107 @@ for((synonym, cosineSimilarity) <- synonyms) {
-## TFIDF \ No newline at end of file +## StandardScaler + +Standardizes features by scaling to unit variance and/or removing the mean using column summary +statistics on the samples in the training set. This is a very common pre-processing step. + +For example, RBF kernel of Support Vector Machines or the L1 and L2 regularized linear models +typically work better when all features have unit variance and/or zero mean. + +Standardization can improve the convergence rate during the optimization process, and also prevents +against features with very large variances exerting an overly large influence during model training. + +### Model Fitting + +[`StandardScaler`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) has the +following parameters in the constructor: + +* `withMean` False by default. Centers the data with mean before scaling. It will build a dense +output, so this does not work on sparse input and will raise an exception. +* `withStd` True by default. Scales the data to unit variance. + +We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in +`StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then +return a model which can transform the input dataset into unit variance and/or zero mean features +depending how we configure the `StandardScaler`. + +This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) +which can apply the standardization on a `Vector` to produce a transformed `Vector` or on +an `RDD[Vector]` to produce a transformed `RDD[Vector]`. + +Note that if the variance of a feature is zero, it will return default `0.0` value in the `Vector` +for that feature. + +### Example + +The example below demonstrates how to load a dataset in libsvm format, and standardize the features +so that the new features have unit variance and/or zero mean. + +
+
+{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.feature.StandardScaler +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils + +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + +val scaler1 = new StandardScaler().fit(data.map(x => x.features)) +val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features)) + +// data1 will be unit variance. +val data1 = data.map(x => (x.label, scaler1.transform(x.features))) + +// Without converting the features into dense vectors, transformation with zero mean will raise +// exception on sparse vector. +// data2 will be unit variance and zero mean. +val data2 = data.map(x => (x.label, scaler2.transform(Vectors.dense(x.features.toArray)))) +{% endhighlight %} +
+
+ +## Normalizer + +Normalizer scales individual samples to have unit $L^p$ norm. This is a common operation for text +classification or clustering. For example, the dot product of two $L^2$ normalized TF-IDF vectors +is the cosine similarity of the vectors. + +[`Normalizer`](api/scala/index.html#org.apache.spark.mllib.feature.Normalizer) has the following +parameter in the constructor: + +* `p` Normalization in $L^p$ space, $p = 2$ by default. + +`Normalizer` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) +which can apply the normalization on a `Vector` to produce a transformed `Vector` or on +an `RDD[Vector]` to produce a transformed `RDD[Vector]`. + +Note that if the norm of the input is zero, it will return the input vector. + +### Example + +The example below demonstrates how to load a dataset in libsvm format, and normalizes the features +with $L^2$ norm, and $L^\infty$ norm. + +
+
+{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.feature.Normalizer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils + +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + +val normalizer1 = new Normalizer() +val normalizer2 = new Normalizer(p = Double.PositiveInfinity) + +// Each sample in data1 will be normalized using $L^2$ norm. +val data1 = data.map(x => (x.label, normalizer1.transform(x.features))) + +// Each sample in data2 will be normalized using $L^\infty$ norm. +val data2 = data.map(x => (x.label, normalizer2.transform(x.features))) +{% endhighlight %} +
+
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index ca0a84a8c53fd..94fc98ce4fabe 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -7,12 +7,13 @@ MLlib is Spark's scalable machine learning library consisting of common learning including classification, regression, clustering, collaborative filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: -* [Data types](mllib-basics.html) -* [Basic statistics](mllib-stats.html) - * random data generation - * stratified sampling +* [Data types](mllib-data-types.html) +* [Basic statistics](mllib-statistics.html) * summary statistics + * correlations + * stratified sampling * hypothesis testing + * random data generation * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) * [decision trees](mllib-decision-tree.html) @@ -35,18 +36,23 @@ and the migration guide below will explain all changes between releases. # Dependencies -MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on -[netlib-java](https://github.com/fommil/netlib-java), and -[jblas](https://github.com/mikiobraun/jblas). +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), +which depends on [netlib-java](https://github.com/fommil/netlib-java), +and [jblas](https://github.com/mikiobraun/jblas). `netlib-java` and `jblas` depend on native Fortran routines. You need to install the -[gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries) if it is not -already present on your nodes. MLlib will throw a linking error if it cannot detect these libraries -automatically. Due to license issues, we do not include `netlib-java`'s native libraries in MLlib's -dependency set. If no native library is available at runtime, you will see a warning message. To -use native libraries from `netlib-java`, please include artifact -`com.github.fommil.netlib:all:1.1.2` as a dependency of your project or build your own (see -[instructions](https://github.com/fommil/netlib-java/blob/master/README.md#machine-optimised-system-libraries)). +[gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries) +if it is not already present on your nodes. +MLlib will throw a linking error if it cannot detect these libraries automatically. +Due to license issues, we do not include `netlib-java`'s native libraries in MLlib's +dependency set under default settings. +If no native library is available at runtime, you will see a warning message. +To use native libraries from `netlib-java`, please build Spark with `-Pnetlib-lgpl` or +include `com.github.fommil.netlib:all:1.1.2` as a dependency of your project. +If you want to use optimized BLAS/LAPACK libraries such as +[OpenBLAS](http://www.openblas.net/), please link its shared libraries to +`/usr/lib/libblas.so.3` and `/usr/lib/liblapack.so.3`, respectively. +BLAS/LAPACK libraries on worker nodes should be built without multithreading. To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. @@ -54,6 +60,32 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 # Migration Guide +## From 1.0 to 1.1 + +The only API changes in MLlib v1.1 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.1: + +1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match +the implementations of trees in +[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) +and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). +In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. +In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. +This depth is specified by the `maxDepth` parameter in +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` +methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +rather than using the old parameter class `Strategy`. These new training methods explicitly +separate classification and regression, and they replace specialized parameter types with +simple `String` types. + +Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +[Decision Trees Guide](mllib-decision-tree.html#examples). + ## From 0.9 to 1.0 In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few @@ -79,7 +111,7 @@ val vector: Vector = Vectors.dense(array) // a dense vector [`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to create sparse vectors. -*Note*. Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. +*Note*: Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`.
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md new file mode 100644 index 0000000000000..c4632413991f1 --- /dev/null +++ b/docs/mllib-statistics.md @@ -0,0 +1,457 @@ +--- +layout: global +title: Basic Statistics - MLlib +displayTitle: MLlib - Basic Statistics +--- + +* Table of contents +{:toc} + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +## Summary statistics + +We provide column summary statistics for `RDD[Vector]` through the function `colStats` +available in `Statistics`. + +
+
+ +[`colStats()`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) returns an instance of +[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} + +val observations: RDD[Vector] = ... // an RDD of Vectors + +// Compute column summary statistics. +val summary: MultivariateStatisticalSummary = Statistics.colStats(observations) +println(summary.mean) // a dense vector containing the mean value for each column +println(summary.variance) // column-wise variance +println(summary.numNonzeros) // number of nonzeros in each column + +{% endhighlight %} +
+ +
+ +[`colStats()`](api/java/org/apache/spark/mllib/stat/Statistics.html) returns an instance of +[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; +import org.apache.spark.mllib.stat.Statistics; + +JavaSparkContext jsc = ... + +JavaRDD mat = ... // an RDD of Vectors + +// Compute column summary statistics. +MultivariateStatisticalSummary summary = Statistics.colStats(mat.rdd()); +System.out.println(summary.mean()); // a dense vector containing the mean value for each column +System.out.println(summary.variance()); // column-wise variance +System.out.println(summary.numNonzeros()); // number of nonzeros in each column + +{% endhighlight %} +
+ +
+[`colStats()`](api/python/pyspark.mllib.stat.Statistics-class.html#colStats) returns an instance of +[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.stat.MultivariateStatisticalSummary-class.html), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight python %} +from pyspark.mllib.stat import Statistics + +sc = ... # SparkContext + +mat = ... # an RDD of Vectors + +# Compute column summary statistics. +summary = Statistics.colStats(mat) +print summary.mean() +print summary.variance() +print summary.numNonzeros() + +{% endhighlight %} +
+ +
+ +## Correlations + +Calculating the correlation between two series of data is a common operation in Statistics. In MLlib +we provide the flexibility to calculate pairwise correlations among many series. The supported +correlation methods are currently Pearson's and Spearman's correlation. + +
+
+[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or +an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.stat.Statistics + +val sc: SparkContext = ... + +val seriesX: RDD[Double] = ... // a series +val seriesY: RDD[Double] = ... // must have the same number of partitions and cardinality as seriesX + +// compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a +// method is not specified, Pearson's method will be used by default. +val correlation: Double = Statistics.corr(seriesX, seriesY, "pearson") + +val data: RDD[Vector] = ... // note that each Vector is a row and not a column + +// calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. +// If a method is not specified, Pearson's method will be used by default. +val correlMatrix: Matrix = Statistics.corr(data, "pearson") + +{% endhighlight %} +
+ +
+[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +calculate correlations between series. Depending on the type of input, two `JavaDoubleRDD`s or +a `JavaRDD`, the output will be a `Double` or the correlation `Matrix` respectively. + +{% highlight java %} +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.mllib.stat.Statistics; + +JavaSparkContext jsc = ... + +JavaDoubleRDD seriesX = ... // a series +JavaDoubleRDD seriesY = ... // must have the same number of partitions and cardinality as seriesX + +// compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a +// method is not specified, Pearson's method will be used by default. +Double correlation = Statistics.corr(seriesX.srdd(), seriesY.srdd(), "pearson"); + +JavaRDD data = ... // note that each Vector is a row and not a column + +// calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. +// If a method is not specified, Pearson's method will be used by default. +Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson"); + +{% endhighlight %} +
+ +
+[`Statistics`](api/python/pyspark.mllib.stat.Statistics-class.html) provides methods to +calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or +an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. + +{% highlight python %} +from pyspark.mllib.stat import Statistics + +sc = ... # SparkContext + +seriesX = ... # a series +seriesY = ... # must have the same number of partitions and cardinality as seriesX + +# Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a +# method is not specified, Pearson's method will be used by default. +print Statistics.corr(seriesX, seriesY, method="pearson") + +data = ... # an RDD of Vectors +# calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. +# If a method is not specified, Pearson's method will be used by default. +print Statistics.corr(data, method="pearson") + +{% endhighlight %} +
+ +
+ +## Stratified sampling + +Unlike the other statistics functions, which reside in MLLib, stratified sampling methods, +`sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified +sampling, the keys can be thought of as a label and the value as a specific attribute. For example +the key can be man or woman, or document ids, and the respective values can be the list of ages +of the people in the population or the list of words in the documents. The `sampleByKey` method +will flip a coin to decide whether an observation will be sampled or not, therefore requires one +pass over the data, and provides an *expected* sample size. `sampleByKeyExact` requires significant +more resources than the per-stratum simple random sampling used in `sampleByKey`, but will provide +the exact sampling size with 99.99% confidence. `sampleByKeyExact` is currently not supported in +python. + +
+
+[`sampleByKeyExact()`](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions) allows users to +sample exactly $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired +fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of +keys. Sampling without replacement requires one additional pass over the RDD to guarantee sample +size, whereas sampling with replacement requires two additional passes. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.PairRDDFunctions + +val sc: SparkContext = ... + +val data = ... // an RDD[(K, V)] of any key value pairs +val fractions: Map[K, Double] = ... // specify the exact fraction desired from each key + +// Get an exact sample from each stratum +val approxSample = data.sampleByKey(withReplacement = false, fractions) +val exactSample = data.sampleByKeyExact(withReplacement = false, fractions) + +{% endhighlight %} +
+ +
+[`sampleByKeyExact()`](api/java/org/apache/spark/api/java/JavaPairRDD.html) allows users to +sample exactly $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired +fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of +keys. Sampling without replacement requires one additional pass over the RDD to guarantee sample +size, whereas sampling with replacement requires two additional passes. + +{% highlight java %} +import java.util.Map; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; + +JavaSparkContext jsc = ... + +JavaPairRDD data = ... // an RDD of any key value pairs +Map fractions = ... // specify the exact fraction desired from each key + +// Get an exact sample from each stratum +JavaPairRDD approxSample = data.sampleByKey(false, fractions); +JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions); + +{% endhighlight %} +
+
+[`sampleByKey()`](api/python/pyspark.rdd.RDD-class.html#sampleByKey) allows users to +sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the +desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the +set of keys. + +*Note:* `sampleByKeyExact()` is currently not supported in Python. + +{% highlight python %} + +sc = ... # SparkContext + +data = ... # an RDD of any key value pairs +fractions = ... # specify the exact fraction desired from each key as a dictionary + +approxSample = data.sampleByKey(False, fractions); + +{% endhighlight %} +
+ +
+ +## Hypothesis testing + +Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically +significant, whether this result occurred by chance or not. MLlib currently supports Pearson's +chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine +whether the goodness of fit or the independence test is conducted. The goodness of fit test requires +an input type of `Vector`, whereas the independence test requires a `Matrix` as input. + +MLlib also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared +independence tests. + +
+
+[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +hypothesis tests. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.Statistics._ + +val sc: SparkContext = ... + +val vec: Vector = ... // a vector composed of the frequencies of events + +// compute the goodness of fit. If a second vector to test against is not supplied as a parameter, +// the test runs against a uniform distribution. +val goodnessOfFitTestResult = Statistics.chiSqTest(vec) +println(goodnessOfFitTestResult) // summary of the test including the p-value, degrees of freedom, + // test statistic, the method used, and the null hypothesis. + +val mat: Matrix = ... // a contingency matrix + +// conduct Pearson's independence test on the input contingency matrix +val independenceTestResult = Statistics.chiSqTest(mat) +println(independenceTestResult) // summary of the test including the p-value, degrees of freedom... + +val obs: RDD[LabeledPoint] = ... // (feature, label) pairs. + +// The contingency table is constructed from the raw (feature, label) pairs and used to conduct +// the independence test. Returns an array containing the ChiSquaredTestResult for every feature +// against the label. +val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) +var i = 1 +featureTestResults.foreach { result => + println(s"Column $i:\n$result") + i += 1 +} // summary of the test + +{% endhighlight %} +
+ +
+[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +hypothesis tests. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.Statistics; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; + +JavaSparkContext jsc = ... + +Vector vec = ... // a vector composed of the frequencies of events + +// compute the goodness of fit. If a second vector to test against is not supplied as a parameter, +// the test runs against a uniform distribution. +ChiSqTestResult goodnessOfFitTestResult = Statistics.chiSqTest(vec); +// summary of the test including the p-value, degrees of freedom, test statistic, the method used, +// and the null hypothesis. +System.out.println(goodnessOfFitTestResult); + +Matrix mat = ... // a contingency matrix + +// conduct Pearson's independence test on the input contingency matrix +ChiSqTestResult independenceTestResult = Statistics.chiSqTest(mat); +// summary of the test including the p-value, degrees of freedom... +System.out.println(independenceTestResult); + +JavaRDD obs = ... // an RDD of labeled points + +// The contingency table is constructed from the raw (feature, label) pairs and used to conduct +// the independence test. Returns an array containing the ChiSquaredTestResult for every feature +// against the label. +ChiSqTestResult[] featureTestResults = Statistics.chiSqTest(obs.rdd()); +int i = 1; +for (ChiSqTestResult result : featureTestResults) { + System.out.println("Column " + i + ":"); + System.out.println(result); // summary of the test + i++; +} + +{% endhighlight %} +
+ +
+ +## Random data generation + +Random data generation is useful for randomized algorithms, prototyping, and performance testing. +MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +uniform, standard normal, or Poisson. + +
+
+[`RandomRDDs`](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.random.RandomRDDs._ + +val sc: SparkContext = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +val u = normalRDD(sc, 1000000L, 10) +// Apply a transform to get a random double RDD following `N(1, 4)`. +val v = u.map(x => 1.0 + 2.0 * x) +{% endhighlight %} +
+ +
+[`RandomRDDs`](api/java/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight java %} +import org.apache.spark.SparkContext; +import org.apache.spark.api.JavaDoubleRDD; +import static org.apache.spark.mllib.random.RandomRDDs.*; + +JavaSparkContext jsc = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +JavaDoubleRDD u = normalJavaRDD(jsc, 1000000L, 10); +// Apply a transform to get a random double RDD following `N(1, 4)`. +JavaDoubleRDD v = u.map( + new Function() { + public Double call(Double x) { + return 1.0 + 2.0 * x; + } + }); +{% endhighlight %} +
+ +
+[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight python %} +from pyspark.mllib.random import RandomRDDs + +sc = ... # SparkContext + +# Generate a random double RDD that contains 1 million i.i.d. values drawn from the +# standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +# Apply a transform to get a random double RDD following `N(1, 4)`. +v = u.map(lambda x: 1.0 + 2.0 * x) +{% endhighlight %} +
+ +
diff --git a/docs/mllib-stats.md b/docs/mllib-stats.md deleted file mode 100644 index f25dca746ba3a..0000000000000 --- a/docs/mllib-stats.md +++ /dev/null @@ -1,167 +0,0 @@ ---- -layout: global -title: Statistics Functionality - MLlib -displayTitle: MLlib - Statistics Functionality ---- - -* Table of contents -{:toc} - - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - -## Random data generation - -Random data generation is useful for randomized algorithms, prototyping, and performance testing. -MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: -uniform, standard normal, or Poisson. - -
-
-[`RandomRDDs`](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory -methods to generate random double RDDs or vector RDDs. -The following example generates a random double RDD, whose values follows the standard normal -distribution `N(0, 1)`, and then map it to `N(1, 4)`. - -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.random.RandomRDDs._ - -val sc: SparkContext = ... - -// Generate a random double RDD that contains 1 million i.i.d. values drawn from the -// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -val u = normalRDD(sc, 1000000L, 10) -// Apply a transform to get a random double RDD following `N(1, 4)`. -val v = u.map(x => 1.0 + 2.0 * x) -{% endhighlight %} -
- -
-[`RandomRDDs`](api/java/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory -methods to generate random double RDDs or vector RDDs. -The following example generates a random double RDD, whose values follows the standard normal -distribution `N(0, 1)`, and then map it to `N(1, 4)`. - -{% highlight java %} -import org.apache.spark.SparkContext; -import org.apache.spark.api.JavaDoubleRDD; -import static org.apache.spark.mllib.random.RandomRDDs.*; - -JavaSparkContext jsc = ... - -// Generate a random double RDD that contains 1 million i.i.d. values drawn from the -// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -JavaDoubleRDD u = normalJavaRDD(jsc, 1000000L, 10); -// Apply a transform to get a random double RDD following `N(1, 4)`. -JavaDoubleRDD v = u.map( - new Function() { - public Double call(Double x) { - return 1.0 + 2.0 * x; - } - }); -{% endhighlight %} -
- -
-[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory -methods to generate random double RDDs or vector RDDs. -The following example generates a random double RDD, whose values follows the standard normal -distribution `N(0, 1)`, and then map it to `N(1, 4)`. - -{% highlight python %} -from pyspark.mllib.random import RandomRDDs - -sc = ... # SparkContext - -# Generate a random double RDD that contains 1 million i.i.d. values drawn from the -# standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -u = RandomRDDs.uniformRDD(sc, 1000000L, 10) -# Apply a transform to get a random double RDD following `N(1, 4)`. -v = u.map(lambda x: 1.0 + 2.0 * x) -{% endhighlight %} -
- -
- -## Stratified Sampling - -## Summary Statistics - -### Multivariate summary statistics - -We provide column summary statistics for `RowMatrix` (note: this functionality is not currently supported in `IndexedRowMatrix` or `CoordinateMatrix`). -If the number of columns is not large, e.g., on the order of thousands, then the -covariance matrix can also be computed as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the -number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, -and is faster if the rows are sparse. - -
-
- -[`computeColumnSummaryStatistics()`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of -[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary - -val mat: RowMatrix = ... // a RowMatrix - -// Compute column summary statistics. -val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() -println(summary.mean) // a dense vector containing the mean value for each column -println(summary.variance) // column-wise variance -println(summary.numNonzeros) // number of nonzeros in each column - -// Compute the covariance matrix. -val cov: Matrix = mat.computeCovariance() -{% endhighlight %} -
- -
- -[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of -[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight java %} -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; - -RowMatrix mat = ... // a RowMatrix - -// Compute column summary statistics. -MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); -System.out.println(summary.mean()); // a dense vector containing the mean value for each column -System.out.println(summary.variance()); // column-wise variance -System.out.println(summary.numNonzeros()); // number of nonzeros in each column - -// Compute the covariance matrix. -Matrix cov = mat.computeCovariance(); -{% endhighlight %} -
-
- - -## Hypothesis Testing diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 6ae780d94046a..624cc744dfd51 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -385,7 +385,7 @@ Apart from text files, Spark's Python API also supports several other data forma * SequenceFile and Hadoop Input/Output Formats -**Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on SparkSQL, in which case SparkSQL is the preferred approach. +**Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on Spark SQL, in which case Spark SQL is the preferred approach. **Writable Support** diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index bd046cfc1837d..1073abb202c56 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -107,7 +107,7 @@ cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooK The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: -1. In `spark.env.sh` set some environment variables: +1. In `spark-env.sh` set some environment variables: * `export MESOS_NATIVE_LIBRARY=`. This path is typically `/lib/libmesos.so` where the prefix is `/usr/local` by default. See Mesos installation instructions above. On Mac OS X, the library is called `libmesos.dylib` instead of @@ -165,6 +165,8 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +# Known issues +- When using the "fine-grained" mode, make sure that your executors always leave 32 MB free on the slaves. Otherwise it can happen that your Spark job does not proceed anymore. Currently, Apache Mesos only offers resources if there are at least 32 MB memory allocatable. But as Spark allocates memory only for the executor and cpu only for tasks, it can happen on high slave memory usage that no new tasks will be started anymore. More details can be found in [MESOS-1688](https://issues.apache.org/jira/browse/MESOS-1688). Alternatively use the "coarse-gained" mode, which is not affected by this issue. # Running Alongside Hadoop diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 9bc20dbf926b2..d8b22f3663d08 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -75,7 +75,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes (none) Comma-separated list of files to be placed in the working directory of each executor. - + spark.yarn.executor.memoryOverhead @@ -125,6 +125,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes the environment of the executor launcher. + + spark.yarn.containerLauncherMaxThreads + 25 + + The maximum number of threads to use in the application master for launching executor containers. + + # Launching Spark on YARN diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 34accade36ea9..d83efa4bab324 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -13,10 +13,10 @@ title: Spark SQL Programming Guide Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed -[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects along with +[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of +[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects, along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) +in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`. @@ -26,10 +26,10 @@ All of the examples on this page use sample data included in the Spark distribut
Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, -[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed -[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects along with +[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of +[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table -in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) +in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
@@ -37,10 +37,10 @@ file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive]( Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed -[Row](api/python/pyspark.sql.Row-class.html) objects along with +[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of +[Row](api/python/pyspark.sql.Row-class.html) objects, along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) +in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell. @@ -68,6 +68,16 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext.createSchemaRDD {% endhighlight %} +In addition to the basic SQLContext, you can also create a HiveContext, which provides a +superset of the functionality provided by the basic SQLContext. Additional features include +the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the +ability to read data from Hive tables. To use a HiveContext, you do not need to have an +existing Hive setup, and all of the data sources available to a SQLContext are still available. +HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using HiveContext +is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to +feature parity with a HiveContext. +
@@ -81,6 +91,16 @@ JavaSparkContext sc = ...; // An existing JavaSparkContext. JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); {% endhighlight %} +In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict +super set of the functionality provided by the basic SQLContext. Additional features include +the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the +ability to read data from Hive tables. To use a HiveContext, you do not need to have an +existing Hive setup, and all of the data sources available to a SQLContext are still available. +HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using HiveContext +is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to +feature parity with a HiveContext. +
@@ -94,36 +114,52 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) {% endhighlight %} +In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict +super set of the functionality provided by the basic SQLContext. Additional features include +the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the +ability to read data from Hive tables. To use a HiveContext, you do not need to have an +existing Hive setup, and all of the data sources available to a SQLContext are still available. +HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using HiveContext +is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to +feature parity with a HiveContext. +
+The specific variant of SQL that is used to parse queries can also be selected using the +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, + this is recommended for most use cases. + # Data Sources -
-
Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. -
- -
-Spark SQL supports operating on a variety of data sources through the `JavaSchemaRDD` interface. -Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. -
- -
-Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. -
-
+A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. +Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section +describes the various methods for loading data into a SchemaRDD. ## RDDs +Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first +method uses reflection to infer the schema of an RDD that contains specific types of objects. This +reflection based approach leads to more concise code and works well when you already know the schema +while writing your Spark application. + +The second method for creating SchemaRDDs is through a programmatic interface that allows you to +construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows +you to construct SchemaRDDs when the columns and their types are not known until runtime. + +### Inferring the Schema Using Reflection
-One type of table that is supported by Spark SQL is an RDD of Scala case classes. The case class +The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes +to a SchemaRDD. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be @@ -156,8 +192,9 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
-One type of table that is supported by Spark SQL is an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly). The BeanInfo -defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain +Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) +into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -192,7 +229,7 @@ for the JavaBean. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc) +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( @@ -229,24 +266,24 @@ List teenagerNames = teenagers.map(new Function() {
-One type of table that is supported by Spark SQL is an RDD of dictionaries. The keys of the -dictionary define the columns names of the table, and the types are inferred by looking at the first -row. Any RDD of dictionaries can converted to a SchemaRDD and then registered as a table. Tables -can be used in subsequent SQL statements. +Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of +key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, +and the types are inferred by looking at the first row. Since we currently only look at the first +row, it is important that there is no missing data in the first row of the RDD. In future versions we +plan to more completely infer the schema by looking at more data, similar to the inference that is +performed on JSON files. {% highlight python %} # sc is an existing SparkContext. -from pyspark.sql import SQLContext +from pyspark.sql import SQLContext, Row sqlContext = SQLContext(sc) # Load a text file and convert each line to a dictionary. lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) -people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) +people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) # Infer the schema, and register the SchemaRDD as a table. -# In future versions of PySpark we would like to add support for registering RDDs with other -# datatypes as tables schemaPeople = sqlContext.inferSchema(people) schemaPeople.registerTempTable("people") @@ -263,15 +300,191 @@ for teenName in teenNames.collect():
-**Note that Spark SQL currently uses a very basic SQL parser.** -Users that want a more complete dialect of SQL should look at the HiveQL support provided by -`HiveContext`. +### Programmatically Specifying the Schema + +
+ +
+ +When case classes cannot be defined ahead of time (for example, +the structure of records is encoded in a string, or a text dataset will be parsed +and fields will be projected differently for different users), +a `SchemaRDD` can be created programmatically with three steps. + +1. Create an RDD of `Row`s from the original RDD; +2. Create the schema represented by a `StructType` matching the structure of +`Row`s in the RDD created in Step 1. +3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +by `SQLContext`. + +For example: +{% highlight scala %} +// sc is an existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// Create an RDD +val people = sc.textFile("examples/src/main/resources/people.txt") + +// The schema is encoded in a string +val schemaString = "name age" + +// Import Spark SQL data types and Row. +import org.apache.spark.sql._ + +// Generate the schema based on the string of schema +val schema = + StructType( + schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true))) + +// Convert records of the RDD (people) to Rows. +val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) + +// Apply the schema to the RDD. +val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema) + +// Register the SchemaRDD as a table. +peopleSchemaRDD.registerTempTable("people") + +// SQL statements can be run by using the sql methods provided by sqlContext. +val results = sqlContext.sql("SELECT name FROM people") + +// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The columns of a row in the result can be accessed by ordinal. +results.map(t => "Name: " + t(0)).collect().foreach(println) +{% endhighlight %} + + +
+ +
+ +When JavaBean classes cannot be defined ahead of time (for example, +the structure of records is encoded in a string, or a text dataset will be parsed and +fields will be projected differently for different users), +a `SchemaRDD` can be created programmatically with three steps. + +1. Create an RDD of `Row`s from the original RDD; +2. Create the schema represented by a `StructType` matching the structure of +`Row`s in the RDD created in Step 1. +3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +by `JavaSQLContext`. + +For example: +{% highlight java %} +// Import factory methods provided by DataType. +import org.apache.spark.sql.api.java.DataType +// Import StructType and StructField +import org.apache.spark.sql.api.java.StructType +import org.apache.spark.sql.api.java.StructField +// Import Row. +import org.apache.spark.sql.api.java.Row + +// sc is an existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); + +// Load a text file and convert each line to a JavaBean. +JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); + +// The schema is encoded in a string +String schemaString = "name age"; + +// Generate the schema based on the string of schema +List fields = new ArrayList(); +for (String fieldName: schemaString.split(" ")) { + fields.add(DataType.createStructField(fieldName, DataType.StringType, true)); +} +StructType schema = DataType.createStructType(fields); + +// Convert records of the RDD (people) to Rows. +JavaRDD rowRDD = people.map( + new Function() { + public Row call(String record) throws Exception { + String[] fields = record.split(","); + return Row.create(fields[0], fields[1].trim()); + } + }); + +// Apply the schema to the RDD. +JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema); + +// Register the SchemaRDD as a table. +peopleSchemaRDD.registerTempTable("people"); + +// SQL can be run over RDDs that have been registered as tables. +JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people"); + +// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The columns of a row in the result can be accessed by ordinal. +List names = results.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } +}).collect(); + +{% endhighlight %} + +
+ +
+ +When a dictionary of kwargs cannot be defined ahead of time (for example, +the structure of records is encoded in a string, or a text dataset will be parsed and +fields will be projected differently for different users), +a `SchemaRDD` can be created programmatically with three steps. + +1. Create an RDD of tuples or lists from the original RDD; +2. Create the schema represented by a `StructType` matching the structure of +tuples or lists in the RDD created in the step 1. +3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`. + +For example: +{% highlight python %} +# Import SQLContext and data types +from pyspark.sql import * + +# sc is an existing SparkContext. +sqlContext = SQLContext(sc) + +# Load a text file and convert each line to a tuple. +lines = sc.textFile("examples/src/main/resources/people.txt") +parts = lines.map(lambda l: l.split(",")) +people = parts.map(lambda p: (p[0], p[1].strip())) + +# The schema is encoded in a string. +schemaString = "name age" + +fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split()] +schema = StructType(fields) + +# Apply the schema to the RDD. +schemaPeople = sqlContext.applySchema(people, schema) + +# Register the SchemaRDD as a table. +schemaPeople.registerTempTable("people") + +# SQL can be run over SchemaRDDs that have been registered as a table. +results = sqlContext.sql("SELECT name FROM people") + +# The results of SQL queries are RDDs and support all the normal RDD operations. +names = results.map(lambda p: "Name: " + p.name) +for name in names.collect(): + print name +{% endhighlight %} + + +
+ +
## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. Using the data from the above example: +of the original data. + +### Loading Data Programmatically + +Using the data from the above example:
@@ -349,7 +562,40 @@ for teenName in teenNames.collect():
-
+
+ +### Configuration + +Configuration of Parquet can be done using the `setConf` method on SQLContext or by running +`SET key=value` commands using SQL. + + + + + + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.parquet.binaryAsStringfalse + Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do + not differentiate between binary data and strings when writing out the Parquet schema. This + flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. +
spark.sql.parquet.cacheMetadatafalse + Turns on caching of Parquet schema metadata. Can speed up querying of static data. +
spark.sql.parquet.compression.codecsnappy + Sets the compression codec use when writing Parquet files. Acceptable values include: + uncompressed, snappy, gzip, lzo. +
## JSON Datasets
@@ -474,10 +720,10 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run '`sbt/sbt -Phive assembly/assembly`' (or use `-Phive` for maven). +In order to use Hive you must first run "`sbt/sbt -Phive assembly/assembly`" (or use `-Phive` for maven). This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries -(SerDes) in order to acccess data stored in Hive. +(SerDes) in order to access data stored in Hive. Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. @@ -493,13 +739,13 @@ directory. {% highlight scala %} // sc is an existing SparkContext. -val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) +val sqlContext = new org.apache.spark.sql.hive.HiveContext(sc) -hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -hiveContext.sql("FROM src SELECT key, value").collect().foreach(println) +sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) {% endhighlight %}
@@ -513,13 +759,13 @@ expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); +JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); -hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL. -Row[] results = hiveContext.sql("FROM src SELECT key, value").collect(); +Row[] results = sqlContext.sql("FROM src SELECT key, value").collect(); {% endhighlight %} @@ -535,52 +781,101 @@ expressed in HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext -hiveContext = HiveContext(sc) +sqlContext = HiveContext(sc) -hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = hiveContext.sql("FROM src SELECT key, value").collect() +results = sqlContext.sql("FROM src SELECT key, value").collect() {% endhighlight %}
-# Writing Language-Integrated Relational Queries +# Performance Tuning -**Language-Integrated queries are currently only supported in Scala.** +For some workloads it is possible to improve performance by either caching data in memory, or by +turning on some experimental options. -Spark SQL also supports a domain specific language for writing queries. Once again, -using the data from the above examples: +## Caching Data In Memory -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import sqlContext._ -val people: RDD[Person] = ... // An RDD of case class objects, from the first example. - -// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' -val teenagers = people.where('age >= 10).where('age <= 19).select('name) -teenagers.map(t => "Name: " + t(0)).collect().foreach(println) -{% endhighlight %} - -The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers -prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are -evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). +Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. +Then Spark SQL will scan only required columns and will automatically tune compression to minimize +memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. - +Note that if you call `cache` rather than `cacheTable`, tables will _not_ be cached using +the in-memory columnar format, and therefore `cacheTable` is strongly recommended for this use case. + +Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running +`SET key=value` commands using SQL. + + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.inMemoryColumnarStorage.compressedfalse + When set to true Spark SQL will automatically select a compression codec for each column based + on statistics of the data. +
spark.sql.inMemoryColumnarStorage.batchSize1000 + Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization + and compression, but risk OOMs when caching data. +
+ +## Other Configuration Options + +The following options can also be used to tune the performance of query execution. It is possible +that these options will be deprecated in future release as more optimizations are performed automatically. + + + + + + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.autoBroadcastJoinThreshold10000 + Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when + performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently + statistics are only supported for Hive Metastore tables where the command + `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. +
spark.sql.codegenfalse + When true, code will be dynamically generated at runtime for expression evaluation in a specific + query. For some queries with complicated expression this option can lead to significant speed-ups. + However, for simple queries this can actually slow down query execution. +
spark.sql.shuffle.partitions200 + Configures the number of partitions to use when shuffling data for joins or aggregations. +
+ +# Other SQL Interfaces + +Spark SQL also supports interfaces for running SQL queries directly without the need to write any +code. ## Running the Thrift JDBC server -The Thrift JDBC server implemented here corresponds to the [`HiveServer2`] -(https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test -the JDBC server with the beeline script comes with either Spark or Hive 0.12. In order to use Hive -you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use `-Phive-thriftserver` -for maven). +The Thrift JDBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) +in Hive 0.12. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12. To start the JDBC server, run the following in the Spark directory: @@ -599,55 +894,67 @@ Connect to the JDBC server in beeline with: Beeline will ask you for a username and password. In non-secure mode, simply enter the username on your machine and a blank password. For secure mode, please follow the instructions given in the -[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients) +[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. -You may also use the beeline script comes with Hive. +You may also use the beeline script that comes with Hive. +## Running the Spark SQL CLI + +The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute +queries input from the command line. Note that the Spark SQL CLI cannot talk to the Thrift JDBC server. + +To start the Spark SQL CLI, run the following in the Spark directory: + + ./bin/spark-sql + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +You may run `./bin/spark-sql --help` for a complete list of all available +options. + +# Compatibility with Other Systems + +## Migration Guide for Shark User + +### Scheduling +s To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, users can set the `spark.sql.thriftserver.scheduler.pool` variable: SET spark.sql.thriftserver.scheduler.pool=accounting; -### Migration Guide for Shark Users - -#### Reducer number +### Reducer number In Shark, default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark -SQL deprecates this property by a new property `spark.sql.shuffle.partitions`, whose default value +SQL deprecates this property in favor of `spark.sql.shuffle.partitions`, whose default value is 200. Users may customize this property via `SET`: -``` -SET spark.sql.shuffle.partitions=10; -SELECT page, count(*) c FROM logs_last_month_cached -GROUP BY page ORDER BY c DESC LIMIT 10; -``` + SET spark.sql.shuffle.partitions=10; + SELECT page, count(*) c + FROM logs_last_month_cached + GROUP BY page ORDER BY c DESC LIMIT 10; You may also put this property in `hive-site.xml` to override the default value. For now, the `mapred.reduce.tasks` property is still recognized, and is converted to `spark.sql.shuffle.partitions` automatically. -#### Caching +### Caching The `shark.cache` table property no longer exists, and tables whose name end with `_cached` are no -longer automcatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to +longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to let user control table caching explicitly: -``` -CACHE TABLE logs_last_month; -UNCACHE TABLE logs_last_month; -``` + CACHE TABLE logs_last_month; + UNCACHE TABLE logs_last_month; -**NOTE** `CACHE TABLE tbl` is lazy, it only marks table `tbl` as "need to by cached if necessary", -but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be -cached, you may simply count the table immediately after executing `CACHE TABLE`: +**NOTE:** `CACHE TABLE tbl` is lazy, similar to `.cache` on an RDD. This command only marks `tbl` to ensure that +partitions are cached when calculated but doesn't actually cache it until a query that touches `tbl` is executed. +To force the table to be cached, you may simply count the table immediately after executing `CACHE TABLE`: -``` -CACHE TABLE logs_last_month; -SELECT COUNT(1) FROM logs_last_month; -``` + CACHE TABLE logs_last_month; + SELECT COUNT(1) FROM logs_last_month; Several caching related features are not supported yet: @@ -655,71 +962,75 @@ Several caching related features are not supported yet: * RDD reloading * In-memory cache write through policy -### Compatibility with Apache Hive +## Compatibility with Apache Hive -#### Deploying in Exising Hive Warehouses +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark +SQL is based on Hive 0.12.0. -Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive +#### Deploying in Existing Hive Warehouses + +The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive installations. You do not need to modify your existing Hive Metastore or change the data placement or partitioning of your tables. -#### Supported Hive Features +### Supported Hive Features Spark SQL supports the vast majority of Hive features, such as: * Hive query statements, including: - * `SELECT` - * `GROUP BY - * `ORDER BY` - * `CLUSTER BY` - * `SORT BY` + * `SELECT` + * `GROUP BY` + * `ORDER BY` + * `CLUSTER BY` + * `SORT BY` * All Hive operators, including: - * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc) - * Arthimatic operators (`+`, `-`, `*`, `/`, `%`, etc) - * Logical operators (`AND`, `&&`, `OR`, `||`, etc) - * Complex type constructors - * Mathemtatical functions (`sign`, `ln`, `cos`, etc) - * String functions (`instr`, `length`, `printf`, etc) + * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc) + * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc) + * Logical operators (`AND`, `&&`, `OR`, `||`, etc) + * Complex type constructors + * Mathematical functions (`sign`, `ln`, `cos`, etc) + * String functions (`instr`, `length`, `printf`, etc) * User defined functions (UDF) * User defined aggregation functions (UDAF) -* User defined serialization formats (SerDe's) +* User defined serialization formats (SerDes) * Joins - * `JOIN` - * `{LEFT|RIGHT|FULL} OUTER JOIN` - * `LEFT SEMI JOIN` - * `CROSS JOIN` + * `JOIN` + * `{LEFT|RIGHT|FULL} OUTER JOIN` + * `LEFT SEMI JOIN` + * `CROSS JOIN` * Unions -* Sub queries - * `SELECT col FROM ( SELECT a + b AS col from t1) t2` +* Sub-queries + * `SELECT col FROM ( SELECT a + b AS col from t1) t2` * Sampling * Explain * Partitioned tables * All Hive DDL Functions, including: - * `CREATE TABLE` - * `CREATE TABLE AS SELECT` - * `ALTER TABLE` + * `CREATE TABLE` + * `CREATE TABLE AS SELECT` + * `ALTER TABLE` * Most Hive Data types, including: - * `TINYINT` - * `SMALLINT` - * `INT` - * `BIGINT` - * `BOOLEAN` - * `FLOAT` - * `DOUBLE` - * `STRING` - * `BINARY` - * `TIMESTAMP` - * `ARRAY<>` - * `MAP<>` - * `STRUCT<>` - -#### Unsupported Hive Functionality + * `TINYINT` + * `SMALLINT` + * `INT` + * `BIGINT` + * `BOOLEAN` + * `FLOAT` + * `DOUBLE` + * `STRING` + * `BINARY` + * `TIMESTAMP` + * `ARRAY<>` + * `MAP<>` + * `STRUCT<>` + +### Unsupported Hive Functionality Below is a list of Hive features that we don't support yet. Most of these features are rarely used in Hive deployments. **Major Hive Features** +* Spark SQL does not currently support inserting to tables using dynamic partitioning. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. @@ -729,11 +1040,11 @@ in Hive deployments. have the same input format. * Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. -* `UNIONTYPE` +* `UNION` type and `DATE` type * Unique join * Single query multi insert * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at - the moment. + the moment and only supports populating the sizeInBytes field of the hive metastore. **Hive Input/Output Formats** @@ -743,7 +1054,7 @@ in Hive deployments. **Hive Optimizations** A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are -not necessary due to Spark SQL's in-memory computational model. Others are slotted for future +less important due to Spark SQL's in-memory computational model. Others are slotted for future releases of Spark SQL. * Block level bitmap indexes and virtual columns (used to build indexes) @@ -751,9 +1062,7 @@ releases of Spark SQL. Hive automatically converts the join into a map join. We are adding this auto conversion in the next release. * Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you - need to control the degree of parallelism post-shuffle using "SET - spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the - next release. + need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". * Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still launches tasks to compute the result. * Skew data flag: Spark SQL does not follow the skew data flags in Hive. @@ -762,25 +1071,471 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. -## Running the Spark SQL CLI +# Writing Language-Integrated Relational Queries -The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute -queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. +**Language-Integrated queries are experimental and currently only supported in Scala.** -To start the Spark SQL CLI, run the following in the Spark directory: +Spark SQL also supports a domain specific language for writing queries. Once again, +using the data from the above examples: - ./bin/spark-sql +{% highlight scala %} +// sc is an existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) +// Importing the SQL context gives access to all the public SQL functions and implicit conversions. +import sqlContext._ +val people: RDD[Person] = ... // An RDD of case class objects, from the first example. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. -You may run `./bin/spark-sql --help` for a complete list of all available -options. +// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' +val teenagers = people.where('age >= 10).where('age <= 19).select('name) +teenagers.map(t => "Name: " + t(0)).collect().foreach(println) +{% endhighlight %} + +The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers +prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are +evaluated by the SQL execution engine. A full list of the functions supported can be found in the +[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). -# Cached tables + -Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. -Then Spark SQL will scan only required columns and will automatically tune compression to minimize -memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. +# Spark SQL DataType Reference + +* Numeric types + - `ByteType`: Represents 1-byte signed integer numbers. + The range of numbers is from `-128` to `127`. + - `ShortType`: Represents 2-byte signed integer numbers. + The range of numbers is from `-32768` to `32767`. + - `IntegerType`: Represents 4-byte signed integer numbers. + The range of numbers is from `-2147483648` to `2147483647`. + - `LongType`: Represents 8-byte signed integer numbers. + The range of numbers is from `-9223372036854775808` to `9223372036854775807`. + - `FloatType`: Represents 4-byte single-precision floating point numbers. + - `DoubleType`: Represents 8-byte double-precision floating point numbers. + - `DecimalType`: +* String type + - `StringType`: Represents character string values. +* Binary type + - `BinaryType`: Represents byte sequence values. +* Boolean type + - `BooleanType`: Represents boolean values. +* Datetime type + - `TimestampType`: Represents values comprising values of fields year, month, day, + hour, minute, and second. +* Complex types + - `ArrayType(elementType, containsNull)`: Represents values comprising a sequence of + elements with the type of `elementType`. `containsNull` is used to indicate if + elements in a `ArrayType` value can have `null` values. + - `MapType(keyType, valueType, valueContainsNull)`: + Represents values comprising a set of key-value pairs. The data type of keys are + described by `keyType` and the data type of values are described by `valueType`. + For a `MapType` value, keys are not allowed to have `null` values. `valueContainsNull` + is used to indicate if values of a `MapType` value can have `null` values. + - `StructType(fields)`: Represents values with the structure described by + a sequence of `StructField`s (`fields`). + * `StructField(name, dataType, nullable)`: Represents a field in a `StructType`. + The name of a field is indicated by `name`. The data type of a field is indicated + by `dataType`. `nullable` is used to indicate if values of this fields can have + `null` values. + +
+
+ +All data types of Spark SQL are located in the package `org.apache.spark.sql`. +You can access them by doing +{% highlight scala %} +import org.apache.spark.sql._ +{% endhighlight %} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeValue type in ScalaAPI to access or create a data type
ByteType Byte + ByteType +
ShortType Short + ShortType +
IntegerType Int + IntegerType +
LongType Long + LongType +
FloatType Float + FloatType +
DoubleType Double + DoubleType +
DecimalType scala.math.sql.BigDecimal + DecimalType +
StringType String + StringType +
BinaryType Array[Byte] + BinaryType +
BooleanType Boolean + BooleanType +
TimestampType java.sql.Timestamp + TimestampType +
ArrayType scala.collection.Seq + ArrayType(elementType, [containsNull])
+ Note: The default value of containsNull is false. +
MapType scala.collection.Map + MapType(keyType, valueType, [valueContainsNull])
+ Note: The default value of valueContainsNull is true. +
StructType org.apache.spark.sql.Row + StructType(fields)
+ Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
StructField The value type in Scala of the data type of this field + (For example, Int for a StructField with the data type IntegerType) + StructField(name, dataType, nullable) +
+ +
+ +
+ +All data types of Spark SQL are located in the package of +`org.apache.spark.sql.api.java`. To access or create a data type, +please use factory methods provided in +`org.apache.spark.sql.api.java.DataType`. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeValue type in JavaAPI to access or create a data type
ByteType byte or Byte + DataType.ByteType +
ShortType short or Short + DataType.ShortType +
IntegerType int or Integer + DataType.IntegerType +
LongType long or Long + DataType.LongType +
FloatType float or Float + DataType.FloatType +
DoubleType double or Double + DataType.DoubleType +
DecimalType java.math.BigDecimal + DataType.DecimalType +
StringType String + DataType.StringType +
BinaryType byte[] + DataType.BinaryType +
BooleanType boolean or Boolean + DataType.BooleanType +
TimestampType java.sql.Timestamp + DataType.TimestampType +
ArrayType java.util.List + DataType.createArrayType(elementType)
+ Note: The value of containsNull will be false
+ DataType.createArrayType(elementType, containsNull). +
MapType java.util.Map + DataType.createMapType(keyType, valueType)
+ Note: The value of valueContainsNull will be true.
+ DataType.createMapType(keyType, valueType, valueContainsNull)
+
StructType org.apache.spark.sql.api.java + DataType.createStructType(fields)
+ Note: fields is a List or an array of StructFields. + Also, two fields with the same name are not allowed. +
StructField The value type in Java of the data type of this field + (For example, int for a StructField with the data type IntegerType) + DataType.createStructField(name, dataType, nullable) +
+ +
+ +
+ +All data types of Spark SQL are located in the package of `pyspark.sql`. +You can access them by doing +{% highlight python %} +from pyspark.sql import * +{% endhighlight %} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeValue type in PythonAPI to access or create a data type
ByteType + int or long
+ Note: Numbers will be converted to 1-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -128 to 127. +
+ ByteType() +
ShortType + int or long
+ Note: Numbers will be converted to 2-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -32768 to 32767. +
+ ShortType() +
IntegerType int or long + IntegerType() +
LongType + long
+ Note: Numbers will be converted to 8-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of + -9223372036854775808 to 9223372036854775807. + Otherwise, please convert data to decimal.Decimal and use DecimalType. +
+ LongType() +
FloatType + float
+ Note: Numbers will be converted to 4-byte single-precision floating + point numbers at runtime. +
+ FloatType() +
DoubleType float + DoubleType() +
DecimalType decimal.Decimal + DecimalType() +
StringType string + StringType() +
BinaryType bytearray + BinaryType() +
BooleanType bool + BooleanType() +
TimestampType datetime.datetime + TimestampType() +
ArrayType list, tuple, or array + ArrayType(elementType, [containsNull])
+ Note: The default value of containsNull is False. +
MapType dict + MapType(keyType, valueType, [valueContainsNull])
+ Note: The default value of valueContainsNull is True. +
StructType list or tuple + StructType(fields)
+ Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
StructField The value type in Python of the data type of this field + (For example, Int for a StructField with the data type IntegerType) + StructField(name, dataType, nullable) +
+ +
+ +
-Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in -in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to -cache tables. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md new file mode 100644 index 0000000000000..c39ef1ce59e1c --- /dev/null +++ b/docs/storage-openstack-swift.md @@ -0,0 +1,152 @@ +--- +layout: global +title: Accessing OpenStack Swift from Spark +--- + +Spark's support for Hadoop InputFormat allows it to process data in OpenStack Swift using the +same URI formats as in Hadoop. You can specify a path in Swift as input through a +URI of the form swift://container.PROVIDER/path. You will also need to set your +Swift security credentials, through core-site.xml or via +SparkContext.hadoopConfiguration. +Current Swift driver requires Swift to use Keystone authentication method. + +# Configuring Swift for Better Data Locality + +Although not mandatory, it is recommended to configure the proxy server of Swift with +list_endpoints to have better data locality. More information is +[available here](https://github.com/openstack/swift/blob/master/swift/common/middleware/list_endpoints.py). + + +# Dependencies + +The Spark application should include hadoop-openstack dependency. +For example, for Maven support, add the following to the pom.xml file: + +{% highlight xml %} + + ... + + org.apache.hadoop + hadoop-openstack + 2.3.0 + + ... + +{% endhighlight %} + + +# Configuration Parameters + +Create core-site.xml and place it inside Spark's conf directory. +There are two main categories of parameters that should to be configured: declaration of the +Swift driver and the parameters that are required by Keystone. + +Configuration of Hadoop to use Swift File system achieved via + + + + + + + +
Property NameValue
fs.swift.implorg.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem
+ +Additional parameters required by Keystone (v2.0) and should be provided to the Swift driver. Those +parameters will be used to perform authentication in Keystone to access Swift. The following table +contains a list of Keystone mandatory parameters. PROVIDER can be any name. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Property NameMeaningRequired
fs.swift.service.PROVIDER.auth.urlKeystone Authentication URLMandatory
fs.swift.service.PROVIDER.auth.endpoint.prefixKeystone endpoints prefixOptional
fs.swift.service.PROVIDER.tenantTenantMandatory
fs.swift.service.PROVIDER.usernameUsernameMandatory
fs.swift.service.PROVIDER.passwordPasswordMandatory
fs.swift.service.PROVIDER.http.portHTTP portMandatory
fs.swift.service.PROVIDER.regionKeystone regionMandatory
fs.swift.service.PROVIDER.publicIndicates if all URLs are publicMandatory
+ +For example, assume PROVIDER=SparkTest and Keystone contains user tester with password testing +defined for tenant test. Then core-site.xml should include: + +{% highlight xml %} + + + fs.swift.impl + org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem + + + fs.swift.service.SparkTest.auth.url + http://127.0.0.1:5000/v2.0/tokens + + + fs.swift.service.SparkTest.auth.endpoint.prefix + endpoints + + fs.swift.service.SparkTest.http.port + 8080 + + + fs.swift.service.SparkTest.region + RegionOne + + + fs.swift.service.SparkTest.public + true + + + fs.swift.service.SparkTest.tenant + test + + + fs.swift.service.SparkTest.username + tester + + + fs.swift.service.SparkTest.password + testing + + +{% endhighlight %} + +Notice that +fs.swift.service.PROVIDER.tenant, +fs.swift.service.PROVIDER.username, +fs.swift.service.PROVIDER.password contains sensitive information and keeping them in +core-site.xml is not always a good approach. +We suggest to keep those parameters in core-site.xml for testing purposes when running Spark +via spark-shell. +For job submissions they should be provided via sparkContext.hadoopConfiguration. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md new file mode 100644 index 0000000000000..d57c3e0ef9ba0 --- /dev/null +++ b/docs/streaming-flume-integration.md @@ -0,0 +1,132 @@ +--- +layout: global +title: Spark Streaming + Flume Integration Guide +--- + +[Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. + +## Approach 1: Flume-style Push-based Approach +Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. + +#### General Requirements +Choose a machine in your cluster such that + +- When your Flume + Spark Streaming application is launched, one of the Spark workers must run on that machine. + +- Flume can be configured to push data to a port on that machine. + +Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data. + +#### Configuring Flume +Configure Flume agent to send data to an Avro sink by having the following in the configuration file. + + agent.sinks = avroSink + agent.sinks.avroSink.type = avro + agent.sinks.avroSink.channel = memoryChannel + agent.sinks.avroSink.hostname = + agent.sinks.avroSink.port = + +See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about +configuring Flume agents. + +#### Configuring Spark Streaming Application +1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `FlumeUtils` and create input DStream as follows. + +
+
+ import org.apache.spark.streaming.flume._ + + val flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala). +
+
+ import org.apache.spark.streaming.flume.*; + + JavaReceiverInputDStream flumeStream = + FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). +
+
+ + Note that the hostname should be the same as the one used by the resource manager in the + cluster (Mesos, YARN or Spark Standalone), so that resource allocation can match the names and launch + the receiver in the right machine. + +3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + +## Approach 2 (Experimental): Pull-based Approach using a Custom Sink +Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. +- Flume pushes data into the sink, and the data stays buffered. +- Spark Streaming uses transactions to pull data from the sink. Transactions succeed only after data is received and replicated by Spark Streaming. +This ensures that better reliability and fault-tolerance than the previous approach. However, this requires configuring Flume to run a custom sink. Here are the configuration steps. + +#### General Requirements +Choose a machine that will run the custom sink in a Flume agent. The rest of the Flume pipeline is configured to send data to that agent. Machines in the Spark cluster should have access to the chosen machine running the custom sink. + +#### Configuring Flume +Configuring Flume on the chosen machine requires the following two steps. + +1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink . + + (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). + + groupId = org.apache.spark + artifactId = spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + + (ii) *Scala library JAR*: Download the Scala library JAR for Scala {{site.SCALA_VERSION}}. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/scala-lang/scala-library/{{site.SCALA_VERSION}}/scala-library-{{site.SCALA_VERSION}}.jar)). + + groupId = org.scala-lang + artifactId = scala-library + version = {{site.SCALA_VERSION}} + +2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. + + agent.sinks = spark + agent.sinks.spark.type = org.apache.spark.streaming.flume.sink.SparkSink + agent.sinks.spark.hostname = + agent.sinks.spark.port = + agent.sinks.spark.channel = memoryChannel + + Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. + +See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about +configuring Flume agents. + +#### Configuring Spark Streaming Application +1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide). + +2. **Programming:** In the streaming application code, import `FlumeUtils` and create input DStream as follows. + +
+
+ import org.apache.spark.streaming.flume._ + + val flumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]) +
+
+ import org.apache.spark.streaming.flume.*; + + JavaReceiverInputDStreamflumeStream = + FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
+
+ + See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). + + Note that each input DStream can be configured to receive data from multiple sinks. + +3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + + diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md new file mode 100644 index 0000000000000..a3b705d4c31d0 --- /dev/null +++ b/docs/streaming-kafka-integration.md @@ -0,0 +1,42 @@ +--- +layout: global +title: Spark Streaming + Kafka Integration Guide +--- +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. + +1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create input DStream as follows. + +
+
+ import org.apache.spark.streaming.kafka._ + + val kafkaStream = KafkaUtils.createStream( + streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). +
+
+ import org.apache.spark.streaming.kafka.*; + + JavaPairReceiverInputDStream kafkaStream = KafkaUtils.createStream( + streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). +
+
+ + *Points to remember:* + + - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + + - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. + +3. **Deploying:** Package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md new file mode 100644 index 0000000000000..c6090d9ec30c7 --- /dev/null +++ b/docs/streaming-kinesis-integration.md @@ -0,0 +1,150 @@ +--- +layout: global +title: Spark Streaming + Kinesis Integration +--- +[Amazon Kinesis](http://aws.amazon.com/kinesis/) is a fully managed service for real-time processing of streaming data at massive scale. +The Kinesis receiver creates an input DStream using the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License (ASL). +The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concepts of Workers, Checkpoints, and Shard Leases. +Here we explain how to configure Spark Streaming to receive data from Kinesis. + +#### Configuring Kinesis + +A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or more shards per the following +[guide](http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html). + + +#### Configuring Spark Streaming Application + +1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + + **Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your application.** + +2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream as follows: + +
+
+ import org.apache.spark.streaming.Duration + import org.apache.spark.streaming.kinesis._ + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + val kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. + +
+
+ import org.apache.spark.streaming.Duration; + import org.apache.spark.streaming.kinesis.*; + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + + JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + +
+
+ + - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + + - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from + - The application name used in the streaming context becomes the Kinesis application name + - The application name must be unique for a given account and region. + - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization. + - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table. + + + - `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region). + + - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application. + + - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + + +3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + *Points to remember at runtime:* + + - Kinesis data processing is ordered per partition and occurs at-least once per message. + + - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamodDB. + + - A single Kinesis stream shard is processed by one input DStream at a time. + +

+ Spark Streaming Kinesis Architecture + +

+ + - A single Kinesis input DStream can read from multiple shards of a Kinesis stream by creating multiple KinesisRecordProcessor threads. + + - Multiple input DStreams running in separate processes/instances can read from a Kinesis stream. + + - You never need more Kinesis input DStreams than the number of Kinesis stream shards as each input DStream will create at least one KinesisRecordProcessor thread that handles a single shard. + + - Horizontal scaling is achieved by adding/removing Kinesis input DStreams (within a single process or across multiple processes/instances) - up to the total number of Kinesis stream shards per the previous point. + + - The Kinesis input DStream will balance the load between all DStreams - even across processes/instances. + + - The Kinesis input DStream will balance the load during re-shard events (merging and splitting) due to changes in load. + + - As a best practice, it's recommended that you avoid re-shard jitter by over-provisioning when possible. + + - Each Kinesis input DStream maintains its own checkpoint info. See the Kinesis Checkpointing section for more details. + + - There is no correlation between the number of Kinesis stream shards and the number of RDD partitions/shards created across the Spark cluster during input DStream processing. These are 2 independent partitioning schemes. + +#### Running the Example +To run the example, + +- Download Spark source and follow the [instructions](building-with-maven.html) to build Spark with profile *-Pkinesis-asl*. + + mvn -Pkinesis-asl -DskipTests clean package + + +- Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created. + +- Set up the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_KEY with your AWS credentials. + +- In the Spark root directory, run the example as + +
+
+ + bin/run-example streaming.KinesisWordCountASL [Kinesis stream name] [endpoint URL] + +
+
+ + bin/run-example streaming.JavaKinesisWordCountASL [Kinesis stream name] [endpoint URL] + +
+
+ + This will wait for data to be received from the Kinesis stream. + +- To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer. + + bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10 + + This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. + +#### Kinesis Checkpointing +- Each Kinesis input DStream periodically stores the current position of the stream in the backing DynamoDB table. This allows the system to recover from failures and continue processing where the DStream left off. + +- Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy. + +- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable. +- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). +- InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md deleted file mode 100644 index 16ad3222105a2..0000000000000 --- a/docs/streaming-kinesis.md +++ /dev/null @@ -1,59 +0,0 @@ ---- -layout: global -title: Spark Streaming Kinesis Receiver ---- - -## Kinesis -###Design -
  • The KinesisReceiver uses the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License.
  • -
  • The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concept of Workers, Checkpoints, and Shard Leases.
  • -
  • The KCL uses DynamoDB to maintain all state. A DynamoDB table is created in the us-east-1 region (regardless of Kinesis stream region) during KCL initialization for each Kinesis application name.
  • -
  • A single KinesisReceiver can process many shards of a stream by spinning up multiple KinesisRecordProcessor threads.
  • -
  • You never need more KinesisReceivers than the number of shards in your stream as each will spin up at least one KinesisRecordProcessor thread.
  • -
  • Horizontal scaling is achieved by autoscaling additional KinesisReceiver (separate processes) or spinning up new KinesisRecordProcessor threads within each KinesisReceiver - up to the number of current shards for a given stream, of course. Don't forget to autoscale back down!
  • - -### Build -
  • Spark supports a Streaming KinesisReceiver, but it is not included in the default build due to Amazon Software Licensing (ASL) restrictions.
  • -
  • To build with the Kinesis Streaming Receiver and supporting ASL-licensed code, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • -
  • All KinesisReceiver-related code, examples, tests, and artifacts live in **$SPARK_HOME/extras/kinesis-asl/**.
  • -
  • Kinesis-based Spark Applications will need to link to the **spark-streaming-kinesis-asl** artifact that is built when **-Pkinesis-asl** is specified.
  • -
  • _**Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • - -###Example -
  • To build the Kinesis example, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • -
  • You need to setup a Kinesis stream at one of the valid Kinesis endpoints with 1 or more shards per the following: http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • -
  • Valid Kinesis endpoints can be found here: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • -
  • When running **locally**, the example automatically determines the number of threads and KinesisReceivers to spin up based on the number of shards configured for the stream. Therefore, **local[n]** is not needed when starting the example as with other streaming examples.
  • -
  • While this example could use a single KinesisReceiver which spins up multiple KinesisRecordProcessor threads to process multiple shards, I wanted to demonstrate unioning multiple KinesisReceivers as a single DStream. (It's a bit confusing in local mode.)
  • -
  • **KinesisWordCountProducerASL** is provided to generate random records into the Kinesis stream for testing.
  • -
  • The example has been configured to immediately replicate incoming stream data to another node by using (StorageLevel.MEMORY_AND_DISK_2) -
  • Spark checkpointing is disabled because the example does not use any stateful or window-based DStream operations such as updateStateByKey and reduceByWindow. If those operations are introduced, you would need to enable checkpointing or risk losing data in the case of a failure.
  • -
  • Kinesis checkpointing is enabled. This means that the example will recover from a Kinesis failure.
  • -
  • The example uses InitialPositionInStream.LATEST strategy to pull from the latest tip of the stream if no Kinesis checkpoint info exists.
  • -
  • In our example, **KinesisWordCount** is the Kinesis application name for both the Scala and Java versions. The use of this application name is described next.
  • - -###Deployment and Runtime -
  • A Kinesis application name must be unique for a given account and region.
  • -
  • A DynamoDB table and CloudWatch namespace are created during KCL initialization using this Kinesis application name. http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization
  • -
  • This DynamoDB table lives in the us-east-1 region regardless of the Kinesis endpoint URL.
  • -
  • Changing the app name or stream name could lead to Kinesis errors as only a single logical application can process a single stream.
  • -
  • If you are seeing errors after changing the app name or stream name, it may be necessary to manually delete the DynamoDB table and start from scratch.
  • -
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the KCL.
  • -
  • The KinesisReceiver uses the DefaultAWSCredentialsProviderChain for AWS credentials which searches for credentials in the following order of precedence:
    -1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    -2) Java System Properties - aws.accessKeyId and aws.secretKey
    -3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    -4) Instance profile credentials - delivered through the Amazon EC2 metadata service -
  • - -###Fault-Tolerance -
  • The combination of Spark Streaming and Kinesis creates 2 different checkpoints that may occur at different intervals.
  • -
  • Checkpointing too frequently against Kinesis will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random backoff retry strategy.
  • -
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last Kinesis checkpoint sequence number recorded per shard (stored in the DynamoDB table).
  • -
  • If no Kinesis checkpoint info exists, the KinesisReceiver will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable.
  • -
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running (and no checkpoint info is being stored.)
  • -
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data.
  • -
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency.
  • -
  • Record processing should be idempotent when possible.
  • -
  • A failed or latent KinesisRecordProcessor within the KinesisReceiver will be detected and automatically restarted by the KCL.
  • -
  • If possible, the KinesisReceiver should be shutdown cleanly in order to trigger a final checkpoint of all KinesisRecordProcessors to avoid duplicate record processing.
  • \ No newline at end of file diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 9f331ed50d2a4..41f170580f452 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -7,12 +7,12 @@ title: Spark Streaming Programming Guide {:toc} # Overview -Spark Streaming is an extension of the core Spark API that allows enables high-throughput, +Spark Streaming is an extension of the core Spark API that allows enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources like Kafka, Flume, Twitter, ZeroMQ, Kinesis or plain old TCP sockets and be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, -and live dashboards. In fact, you can apply Spark's in-built +and live dashboards. In fact, you can apply Spark's [machine learning](mllib-guide.html) algorithms, and [graph processing](graphx-programming-guide.html) algorithms on data streams. @@ -60,35 +60,24 @@ do is as follows.
    First, we import the names of the Spark Streaming classes, and some implicit conversions from StreamingContext into our environment, to add useful methods to -other classes we need (like DStream). - -[StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the -main entry point for all streaming functionality. +other classes we need (like DStream). [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the +main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. {% highlight scala %} +import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -{% endhighlight %} - -Then we create a -[StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) object. -Besides Spark's configuration, we specify that any DStream will be processed -in 1 second batches. -{% highlight scala %} -import org.apache.spark.api.java.function._ -import org.apache.spark.streaming._ -import org.apache.spark.streaming.api._ -// Create a StreamingContext with a local master -// Spark Streaming needs at least two working thread -val ssc = new StreamingContext("local[2]", "NetworkWordCount", Seconds(1)) +// Create a local StreamingContext with two working thread and batch interval of 1 second +val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") +val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} -Using this context, we then create a new DStream -by specifying the IP address and port of the data server. +Using this context, we can create a DStream that represents streaming data from a TCP +source hostname, e.g. `localhost`, and port, e.g. `9999` {% highlight scala %} -// Create a DStream that will connect to serverIP:serverPort, like localhost:9999 +// Create a DStream that will connect to hostname:port, like localhost:9999 val lines = ssc.socketTextStream("localhost", 9999) {% endhighlight %} @@ -112,7 +101,7 @@ import org.apache.spark.streaming.StreamingContext._ val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) -// Print a few of the counts to the console +// Print the first ten elements of each RDD generated in this DStream to the console wordCounts.print() {% endhighlight %} @@ -139,23 +128,25 @@ The complete code can be found in the Spark Streaming example First, we create a [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) object, which is the main entry point for all streaming -functionality. Besides Spark's configuration, we specify that any DStream would be processed -in 1 second batches. +functionality. We create a local StreamingContext with two execution threads, and a batch interval of 1 second. {% highlight java %} +import org.apache.spark.*; import org.apache.spark.api.java.function.*; import org.apache.spark.streaming.*; import org.apache.spark.streaming.api.java.*; import scala.Tuple2; -// Create a StreamingContext with a local master -JavaStreamingContext jssc = new JavaStreamingContext("local[2]", "JavaNetworkWordCount", new Duration(1000)) + +// Create a local StreamingContext with two working thread and batch interval of 1 second +val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") +JavaStreamingContext jssc = new JavaStreamingContext(conf, new Duration(1000)) {% endhighlight %} -Using this context, we then create a new DStream -by specifying the IP address and port of the data server. +Using this context, we can create a DStream that represents streaming data from a TCP +source hostname, e.g. `localhost`, and port, e.g. `9999` {% highlight java %} -// Create a DStream that will connect to serverIP:serverPort, like localhost:9999 +// Create a DStream that will connect to hostname:port, like localhost:9999 JavaReceiverInputDStream lines = jssc.socketTextStream("localhost", 9999); {% endhighlight %} @@ -197,7 +188,9 @@ JavaPairDStream wordCounts = pairs.reduceByKey( return i1 + i2; } }); -wordCounts.print(); // Print a few of the counts to the console + +// Print the first ten elements of each RDD generated in this DStream to the console +wordCounts.print(); {% endhighlight %} The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word, @@ -207,8 +200,8 @@ using a [Function2](api/scala/index.html#org.apache.spark.api.java.function.Func Finally, `wordCounts.print()` will print a few of the counts generated every second. Note that when these lines are executed, Spark Streaming only sets up the computation it -will perform when it is started, and no real processing has started yet. To start the processing -after all the transformations have been setup, we finally call +will perform after it is started, and no real processing has started yet. To start the processing +after all the transformations have been setup, we finally call `start` method. {% highlight java %} jssc.start(); // Start the computation @@ -235,12 +228,12 @@ Then, in a different terminal, you can start the example by using
    {% highlight bash %} -$ ./bin/run-example org.apache.spark.examples.streaming.NetworkWordCount localhost 9999 +$ ./bin/run-example streaming.NetworkWordCount localhost 9999 {% endhighlight %}
    {% highlight bash %} -$ ./bin/run-example org.apache.spark.examples.streaming.JavaNetworkWordCount localhost 9999 +$ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 {% endhighlight %}
    @@ -269,7 +262,7 @@ hello world {% highlight bash %} # TERMINAL 2: RUNNING NetworkWordCount or JavaNetworkWordCount -$ ./bin/run-example org.apache.spark.examples.streaming.NetworkWordCount localhost 9999 +$ ./bin/run-example streaming.NetworkWordCount localhost 9999 ... ------------------------------------------- Time: 1357008430000 ms @@ -281,37 +274,33 @@ Time: 1357008430000 ms -You can also use Spark Streaming directly from the Spark shell: - -{% highlight bash %} -$ bin/spark-shell -{% endhighlight %} - -... and create your StreamingContext by wrapping the existing interactive shell -SparkContext object, `sc`: - -{% highlight scala %} -val ssc = new StreamingContext(sc, Seconds(1)) -{% endhighlight %} - -When working with the shell, you may also need to send a `^D` to your netcat session -to force the pipeline to print the word counts to the console at the sink. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** -# Basics +# Basic Concepts Next, we move beyond the simple example and elaborate on the basics of Spark Streaming that you need to know to write your streaming applications. ## Linking -To write your own Spark Streaming program, you will have to add the following dependency to your - SBT or Maven project: +Similar to Spark, Spark Streaming is available through Maven Central. To write your own Spark Streaming program, you will have to add the following dependency to your SBT or Maven project. + +
    +
    + + + org.apache.spark + spark-streaming_{{site.SCALA_BINARY_VERSION}} + {{site.SPARK_VERSION}} + +
    +
    - groupId = org.apache.spark - artifactId = spark-streaming_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION}} + libraryDependencies += "org.apache.spark" % "spark-streaming_{{site.SCALA_BINARY_VERSION}}" % "{{site.SPARK_VERSION}}" +
    +
    For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark Streaming core @@ -319,68 +308,120 @@ Streaming core artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, some of the common ones are as follows. - + - - +
    SourceArtifact
    Kafka spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}
    Flume spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}
    Kinesis
    spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Apache Software License]
    Twitter spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}
    ZeroMQ spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}}
    MQTT spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}}
    Kinesis
    (built separately)
    kinesis-asl_{{site.SCALA_BINARY_VERSION}}
    For an up-to-date list, please refer to the -[Apache repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION}}%22) +[Apache repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) for the full list of supported sources and artifacts. -## Initializing +*** + +## Initializing StreamingContext + +To initialize a Spark Streaming program, a **StreamingContext** object has to be created which is the main entry point of all Spark Streaming functionality.
    -To initialize a Spark Streaming program in Scala, a -[`StreamingContext`](api/scala/index.html#org.apache.spark.streaming.StreamingContext) -object has to be created, which is the main entry point of all Spark Streaming functionality. -A `StreamingContext` object can be created by using +A [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) object can be created from a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object. {% highlight scala %} -new StreamingContext(master, appName, batchDuration, [sparkHome], [jars]) +import org.apache.spark._ +import org.apache.spark.streaming._ + +val conf = new SparkConf().setAppName(appName).setMaster(master) +val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} -
    -
    -To initialize a Spark Streaming program in Java, a -[`JavaStreamingContext`](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) -object has to be created, which is the main entry point of all Spark Streaming functionality. -A `JavaStreamingContext` object can be created by using +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming +in-process (detects the number of cores in the local system). Note that this internally creates a [SparkContext](api/scala/index.html#org.apache.spark.SparkContext) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. + +The batch interval must be set based on the latency requirements of your application +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +section for more details. + +A `StreamingContext` object can also be created from an existing `SparkContext` object. {% highlight scala %} -new JavaStreamingContext(master, appName, batchInterval, [sparkHome], [jars]) +import org.apache.spark.streaming._ + +val sc = ... // existing SparkContext +val ssc = new StreamingContext(sc, Seconds(1)) {% endhighlight %} + +
    -
    +
    + +A [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) object can be created from a [SparkConf](api/java/index.html?org/apache/spark/SparkConf.html) object. + +{% highlight java %} +import org.apache.spark.*; +import org.apache.spark.streaming.api.java.*; -The `master` parameter is a standard [Spark cluster URL](programming-guide.html#master-urls) -and can be "local" for local testing. The `appName` is a name of your program, -which will be shown on your cluster's web UI. The `batchInterval` is the size of the batches, -as explained earlier. Finally, the last two parameters are needed to deploy your code to a cluster - if running in distributed mode, as described in the - [Spark programming guide](programming-guide.html#deploying-code-on-a-cluster). - Additionally, the underlying SparkContext can be accessed as -`ssc.sparkContext`. +SparkConf conf = new SparkConf().setAppName(appName).setMaster(master); +JavaStreamingContext ssc = new JavaStreamingContext(conf, Duration(1000)); +{% endhighlight %} + +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[*]" to run Spark Streaming +in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) section for more details. -## DStreams -*Discretized Stream* or *DStream* is the basic abstraction provided by Spark Streaming. +A `JavaStreamingContext` object can also be created from an existing `JavaSparkContext`. + +{% highlight java %} +import org.apache.spark.streaming.api.java.*; + +JavaSparkContext sc = ... //existing JavaSparkContext +JavaStreamingContext ssc = new JavaStreamingContext(sc, new Duration(1000)); +{% endhighlight %} +
    +
    + +After a context is defined, you have to do the follow steps. +1. Define the input sources. +1. Setup the streaming computations. +1. Start the receiving and procesing of data using `streamingContext.start()`. +1. The processing will continue until `streamingContext.stop()` is called. + +##### Points to remember: +{:.no_toc} +- Once a context has been started, no new streaming computations can be setup or added to it. +- Once a context has been stopped, it cannot be started (that is, re-used) again. +- Only one StreamingContext can be active in a JVM at the same time. +- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set optional parameter of `stop()` called `stopSparkContext` to false. +- A SparkContext can be re-used to create multiple StreamingContexts, as long as the previous StreamingContext is stopped (without stopping the SparkContext) before the next StreamingContext is created. + +*** + +## Discretized Streams (DStreams) +**Discretized Stream** or **DStream** is the basic abstraction provided by Spark Streaming. It represents a continuous stream of data, either the input data stream received from source, or the processed data stream generated by transforming the input stream. Internally, -it is represented by a continuous sequence of RDDs, which is Spark's abstraction of an immutable, -distributed dataset. Each RDD in a DStream contains data from a certain interval, +a DStream is represented by a continuous series of RDDs, which is Spark's abstraction of an immutable, +distributed dataset (see [Spark Programming Guide](programming-guide.html#resilient-distributed-datasets-rdds) for more details). Each RDD in a DStream contains data from a certain interval, as shown in the following figure.

    @@ -392,8 +433,8 @@ as shown in the following figure. Any operation applied on a DStream translates to operations on the underlying RDDs. For example, in the [earlier example](#a-quick-example) of converting a stream of lines to words, -the `flatmap` operation is applied on each RDD in the `lines` DStream to generate the RDDs of the - `words` DStream. This is shown the following figure. +the `flatMap` operation is applied on each RDD in the `lines` DStream to generate the RDDs of the + `words` DStream. This is shown in the following figure.

    -

    -{% highlight scala %} -ssc.fileStream(dataDirectory) -{% endhighlight %} -
    -
    -{% highlight java %} -jssc.fileStream(dataDirectory); -{% endhighlight %} -
    - +
    +
    + streamingContext.fileStream[keyClass, valueClass, inputFormatClass](dataDirectory) +
    +
    + streamingContext.fileStream(dataDirectory); +
    +
    + + Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that + + + The files must have the same data format. + + The files must be created in the `dataDirectory` by atomically *moving* or *renaming* them into + the data directory. + + Once moved, the files must not be changed. So if the files are being continuously appended, the new data will not be read. -Spark Streaming will monitor the directory `dataDirectory` for any Hadoop-compatible filesystem -and process any files created in that directory. Note that + For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. - * The files must have the same data format. - * The files must be created in the `dataDirectory` by atomically *moving* or *renaming* them into - the data directory. - * Once moved the files must not be changed. +- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details. -For more details on streams from files, Akka actors and sockets, +- **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. + +For more details on streams from sockets, files, and actors, see the API documentations of the relevant functions in [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) for -Scala and [JavaStreamingContext](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) - for Java. +Scala and [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) for Java. + +### Advanced Sources +{:.no_toc} +This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts of dependencies, the functionality to create DStreams from these sources have been moved to separate libraries, that can be [linked to](#linking) explicitly as necessary. For example, if you want to create a DStream using data from Twitter's stream of tweets, you have to do the following. -Additional functionality for creating DStreams from sources such as Kafka, Flume, Kinesis, and Twitter -can be imported by adding the right dependencies as explained in an -[earlier](#linking) section. To take the -case of Kafka, after adding the artifact `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` to the -project dependencies, you can create a DStream from Kafka as +1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the SBT/Maven project dependencies. +1. *Programming*: Import the `TwitterUtils` class and create a DStream with `TwitterUtils.createStream` as shown below. +1. *Deploying*: Generate an uber JAR with all the dependencies (including the dependency `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` and its transitive dependencies) and then deploy the application. This is further explained in the [Deploying section](#deploying-applications).
    {% highlight scala %} -import org.apache.spark.streaming.kafka._ -KafkaUtils.createStream(ssc, kafkaParams, ...) +import org.apache.spark.streaming.twitter._ + +TwitterUtils.createStream(ssc) {% endhighlight %}
    {% highlight java %} -import org.apache.spark.streaming.kafka.*; -KafkaUtils.createStream(jssc, kafkaParams, ...); +import org.apache.spark.streaming.twitter.*; + +TwitterUtils.createStream(jssc); {% endhighlight %}
    -For more details on these additional sources, see the corresponding [API documentation](#where-to-go-from-here). -Furthermore, you can also implement your own custom receiver for your sources. See the -[Custom Receiver Guide](streaming-custom-receivers.html). +Note that these advanced sources are not available in the `spark-shell`, hence applications based on these +advanced sources cannot be tested in the shell. + +Some of these advanced sources are as follows. + +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using + [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information + can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by + Twitter4J library. You can either get the public stream, or get the filtered stream based on a + keywords. See the API documentation ([Scala](api/scala/index.html#org.apache.spark.streaming.twitter.TwitterUtils$), [Java](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html)) and examples ([TwitterPopularTags]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala) and + [TwitterAlgebirdCMS]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala)). + +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can received data from Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. -### Kinesis -[Kinesis](streaming-kinesis.html) +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can receive data from Kafka 0.8.0. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. -## Operations -There are two kinds of DStream operations - _transformations_ and _output operations_. Similar to -RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams -with transformed data. After applying a sequence of transformations to the input streams, output -operations need to called, which write data out to an external data sink, such as a filesystem or a -database. +- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -### Transformations -DStreams support many of the transformations available on normal Spark RDD's. Some of the -common ones are as follows. +### Custom Sources +{:.no_toc} +Input DStreams can also be created out of custom data sources. All you have to do is implement an user-defined **receiver** (see next section to understand what that is) that can receive data from the custom sources and push it into Spark. See the +[Custom Receiver Guide](streaming-custom-receivers.html) for details. + +*** + +## Transformations on DStreams +Similar to that of RDDs, transformations allow the data from the input DStream to be modified. +DStreams support many of the transformations available on normal Spark RDD's. +Some of the common ones are as follows. @@ -557,8 +633,8 @@ common ones are as follows. The last two transformations are worth highlighting again. -

    UpdateStateByKey Operation

    - +#### UpdateStateByKey Operation +{:.no_toc} The `updateStateByKey` operation allows you to maintain arbitrary state while continuously updating it with new information. To use this, you will have to do two steps. @@ -616,8 +692,8 @@ the `(word, 1)` pairs) and the `runningCount` having the previous count. For the Scala code, take a look at the example [StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala). -

    Transform Operation

    - +#### Transform Operation +{:.no_toc} The `transform` operation (along with its variations like `transformWith`) allows arbitrary RDD-to-RDD functions to be applied on a DStream. It can be used to apply any RDD operation that is not exposed in the DStream API. @@ -662,8 +738,8 @@ JavaPairDStream cleanedDStream = wordCounts.transform( In fact, you can also use [machine learning](mllib-guide.html) and [graph computation](graphx-programming-guide.html) algorithms in the `transform` method. -

    Window Operations

    - +#### Window Operations +{:.no_toc} Finally, Spark Streaming also provides *windowed computations*, which allow you to apply transformations over a sliding window of data. This following figure illustrates this sliding window. @@ -678,11 +754,11 @@ window. As shown in the figure, every time the window *slides* over a source DStream, the source RDDs that fall within the window are combined and operated upon to produce the RDDs of the windowed DStream. In this specific case, the operation is applied over last 3 time -units of data, and slides by 2 time units. This shows that any window-based operation needs to +units of data, and slides by 2 time units. This shows that any window operation needs to specify two parameters. * window length - The duration of the window (3 in the figure) - * slide interval - The interval at which the window-based operation is performed (2 in + * sliding interval - The interval at which the window operation is performed (2 in the figure). These two parameters must be multiples of the batch interval of the source DStream (1 in the @@ -720,7 +796,7 @@ JavaPairDStream windowedWordCounts = pairs.reduceByKeyAndWindow -Some of the common window-based operations are as follows. All of these operations take the +Some of the common window operations are as follows. All of these operations take the said two parameters - windowLength and slideInterval.
    TransformationMeaning
    @@ -778,21 +854,27 @@ said two parameters - windowLength and slideInterval.
    -### Output Operations -When an output operator is called, it triggers the computation of a stream. Currently the following -output operators are defined: + +The complete list of DStream transformations is available in the API documentation. For the Scala API, +see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) +and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). +For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) +and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). + +*** + +## Output Operations on DStreams +Output operations allow DStream's data to be pushed out external systems like a database or a file systems. +Since the output operations actually allow the transformed data to be consumed by external systems, +they trigger the actual execution of all the DStream transformations (similar to actions for RDDs). +Currently, the following output operations are defined: - - - - - + @@ -811,17 +893,84 @@ output operators are defined: + + + +
    Output OperationMeaning
    print() Prints first ten elements of every batch of data in a DStream on the driver.
    foreachRDD(func) The fundamental output operator. Applies a function, func, to each RDD generated from - the stream. This function should have side effects, such as printing output, saving the RDD to - external files, or writing it over the network to an external system. Prints first ten elements of every batch of data in a DStream on the driver. + This is useful for development and debugging.
    saveAsObjectFiles(prefix, [suffix]) Save this DStream's contents as a Hadoop file. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    foreachRDD(func) The most generic output operator that applies a function, func, to each RDD generated from + the stream. This function should push the data in each RDD to a external system, like saving the RDD to + files, or writing it over the network to a database. Note that the function func is executed + at the driver, and will usually have RDD actions in it that will force the computation of the streaming RDDs.
    +### Design Patterns for using foreachRDD +{:.no_toc} +`dstream.foreachRDD` is a powerful primitive that allows data to sent out to external systems. +However, it is important to understand how to use this primitive correctly and efficiently. +Some of the common mistakes to avoid are as follows. -The complete list of DStream operations is available in the API documentation. For the Scala API, -see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) -and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). -For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) -and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). +- Often writing data to external system requires creating a connection object +(e.g. TCP connection to a remote server) and using it to send data to a remote system. +For this purpose, a developer may inadvertantly try creating a connection object at +the Spark driver, but try to use it in a Spark worker to save records in the RDDs. +For example (in Scala), + + dstream.foreachRDD(rdd => { + val connection = createNewConnection() // executed at the driver + rdd.foreach(record => { + connection.send(record) // executed at the worker + }) + }) + + This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. + +- However, this can lead to another common mistake - creating a new connection for every record. For example, + + dstream.foreachRDD(rdd => { + rdd.foreach(record => { + val connection = createNewConnection() + connection.send(record) + connection.close() + }) + }) + + Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. + + dstream.foreachRDD(rdd => { + rdd.foreachPartition(partitionOfRecords => { + val connection = createNewConnection() + partitionOfRecords.foreach(record => connection.send(record)) + connection.close() + }) + }) + + This amortizes the connection creation overheads over many records. -## Persistence +- Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches. + One can maintain a static pool of connection objects than can be reused as + RDDs of multiple batches are pushed to the external system, thus further reducing the overheads. + + dstream.foreachRDD(rdd => { + rdd.foreachPartition(partitionOfRecords => { + // ConnectionPool is a static, lazily initialized pool of connections + val connection = ConnectionPool.getConnection() + partitionOfRecords.foreach(record => connection.send(record)) + ConnectionPool.returnConnection(connection) // return to the pool for future reuse + }) + }) + + Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. + + +##### Other points to remember: +{:.no_toc} +- DStreams are executed lazily by the output operations, just like RDDs are lazily executed by RDD actions. Specifically, RDD actions inside the DStream output operations force the processing of the received data. Hence, if your application does not have any output operation, or has output operations like `dstream.foreachRDD()` without any RDD action inside them, then nothing will get executed. The system will simply receive the data and discard it. + +- By default, output operations are executed one-at-a-time. And they are executed in the order they are defined in the application. + +*** + +## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple @@ -838,7 +987,9 @@ memory. This is further discussed in the [Performance Tuning](#memory-tuning) se information on different persistence levels can be found in [Spark Programming Guide](programming-guide.html#rdd-persistence). -## RDD Checkpointing +*** + +## Checkpointing A _stateful operation_ is one which operates over multiple batches of data. This includes all window-based operations and the `updateStateByKey` operation. Since stateful operations have a dependency on previous batches of data, they continuously accumulate metadata over time. @@ -867,10 +1018,19 @@ For DStreams that must be checkpointed (that is, DStreams created by `updateStat `reduceByKeyAndWindow` with inverse function), the checkpoint interval of the DStream is by default set to a multiple of the DStream's sliding interval such that its at least 10 seconds. -## Deployment +*** + +## Deploying Applications A Spark Streaming application is deployed on a cluster in the same way as any other Spark application. Please refer to the [deployment guide](cluster-overview.html) for more details. +Note that the applications +that use [advanced sources](#advanced-sources) (e.g. Kafka, Flume, Twitter) are also required to package the +extra artifact they link to, along with their dependencies, in the JAR that is used to deploy the application. +For example, an application using `TwitterUtils` will have to include +`spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` and all its transitive +dependencies in the application JAR. + If a running Spark Streaming application needs to be upgraded (with new application code), then there are two possible mechanism. @@ -889,7 +1049,9 @@ application left off. Note that this can be done only with input sources that su (like Kafka, and Flume) as data needs to be buffered while the previous application down and the upgraded application is not yet up. -## Monitoring +*** + +## Monitoring Applications Beyond Spark's [monitoring capabilities](monitoring.html), there are additional capabilities specific to Spark Streaming. When a StreamingContext is used, the [Spark web UI](monitoring.html#web-interfaces) shows @@ -912,22 +1074,18 @@ The progress of a Spark Streaming program can also be monitored using the which allows you to get receiver status and processing times. Note that this is a developer API and it is likely to be improved upon (i.e., more information reported) in the future. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** # Performance Tuning Getting the best performance of a Spark Streaming application on a cluster requires a bit of tuning. This section explains a number of the parameters and configurations that can tuned to improve the performance of you application. At a high level, you need to consider two things: -
      -
    1. - Reducing the processing time of each batch of data by efficiently using cluster resources. -
    2. -
    3. - Setting the right batch size such that the batches of data can be processed as fast as they - are received (that is, data processing keeps up with the data ingestion). -
    4. -
    +1. Reducing the processing time of each batch of data by efficiently using cluster resources. + +2. Setting the right batch size such that the batches of data can be processed as fast as they + are received (that is, data processing keeps up with the data ingestion). ## Reducing the Processing Time of each Batch There are a number of optimizations that can be done in Spark to minimize the processing time of @@ -935,15 +1093,41 @@ each batch. These have been discussed in detail in [Tuning Guide](tuning.html). highlights some of the most important ones. ### Level of Parallelism in Data Receiving +{:.no_toc} Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to deserialized and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider parallelizing the data receiving. Note that each input DStream creates a single receiver (running on a worker machine) that receives a single stream of data. Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). -For example, a single Kafka input stream receiving two topics of data can be split into two +For example, a single Kafka input DStream receiving two topics of data can be split into two Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. +thus allowing data to be received in parallel, and increasing overall throughput. These multiple +DStream can be unioned together to create a single DStream. Then the transformations that was +being applied on the single input DStream can applied on the unified stream. This is done as follows. + +
    +
    +{% highlight scala %} +val numStreams = 5 +val kafkaStreams = (1 to numStreams).map { i => KafkaUtils.createStream(...) } +val unifiedStream = streamingContext.union(kafkaStreams) +unifiedStream.print() +{% endhighlight %} +
    +
    +{% highlight java %} +int numStreams = 5; +List> kafkaStreams = new ArrayList>(numStreams); +for (int i = 0; i < numStreams; i++) { + kafkaStreams.add(KafkaUtils.createStream(...)); +} +JavaPairDStream unifiedStream = streamingContext.union(kafkaStreams.get(0), kafkaStreams.subList(1, kafkaStreams.size())); +unifiedStream.print(); +{% endhighlight %} +
    +
    + Another parameter that should be considered is the receiver's blocking interval. For most receivers, the received data is coalesced together into large blocks of data before storing inside Spark's memory. @@ -958,7 +1142,8 @@ This distributes the received batches of data across specified number of machine before further processing. ### Level of Parallelism in Data Processing -Cluster resources maybe under-utilized if the number of parallel tasks used in any stage of the +{:.no_toc} +Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is decided by the [config property] (configuration.html#spark-properties) `spark.default.parallelism`. You can pass the level of @@ -968,6 +1153,7 @@ documentation), or set the [config property](configuration.html#spark-properties `spark.default.parallelism` to change the default. ### Data Serialization +{:.no_toc} The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it. @@ -980,6 +1166,7 @@ The overhead of data serialization can be significant, especially when sub-secon serialization format. Hence, the deserialization overhead of input data may be a bottleneck. ### Task Launching Overheads +{:.no_toc} If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: @@ -994,6 +1181,8 @@ latencies. The overhead can be reduced by the following changes: These changes may reduce batch processing time by 100s of milliseconds, thus allowing sub-second batch size to be viable. +*** + ## Setting the Right Batch Size For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed @@ -1022,6 +1211,8 @@ data rate and/or reducing the batch size. Note that momentary increase in the de temporary data rate increases maybe fine as long as the delay reduces back to a low value (i.e., less than batch size). +*** + ## Memory Tuning Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, @@ -1037,7 +1228,7 @@ Even though keeping the data serialized incurs higher serialization/deserializat it significantly reduces GC pauses. * **Clearing persistent RDDs**: By default, all persistent RDDs generated by Spark Streaming will - be cleared from memory based on Spark's in-built policy (LRU). If `spark.cleaner.ttl` is set, + be cleared from memory based on Spark's built-in policy (LRU). If `spark.cleaner.ttl` is set, then persistent RDDs that are older than that value are periodically cleared. As mentioned [earlier](#operation), this needs to be careful set based on operations used in the Spark Streaming program. However, a smarter unpersisting of RDDs can be enabled by setting the @@ -1051,7 +1242,8 @@ minimizes the variability of GC pauses. Even though concurrent GC is known to re overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** # Fault-tolerance Properties In this section, we are going to discuss the behavior of Spark Streaming application in the event @@ -1124,7 +1316,7 @@ def functionToCreateContext(): StreamingContext = { ssc } -// Get StreaminContext from checkpoint data or create a new one +// Get StreamingContext from checkpoint data or create a new one val context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext _) // Do additional setup on context that needs to be done, @@ -1178,10 +1370,7 @@ context.awaitTermination(); If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. If the directory does not exist (i.e., running for the first time), then the function `contextFactory` will be called to create a new -context and set up the DStreams. See the Scala example -[JavaRecoverableWordCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaRecoverableWordCount.scala) -(note that this example is missing in the 0.9 release, so you can test it using the master branch). -This example appends the word counts of network data into a file. +context and set up the DStreams. You can also explicitly create a `JavaStreamingContext` from the checkpoint data and start the computation by using `new JavaStreamingContext(checkpointDirectory)`. @@ -1208,7 +1397,8 @@ automatically restarted, and the word counts will cont For other deployment environments like Mesos and Yarn, you have to restart the driver through other mechanisms. -

    Recovery Semantics

    +#### Recovery Semantics +{:.no_toc} There are two different failure behaviors based on which input sources are used. @@ -1306,7 +1496,8 @@ in the file. This is what the sequence of outputs would be with and without a dr If the driver had crashed in the middle of the processing of time 3, then it will process time 3 and output 30 after recovery. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** # Migration Guide from 0.9.1 or below to 1.x Between Spark 0.9.1 and Spark 1.0, there were a few API changes made to ensure future API stability. @@ -1332,7 +1523,7 @@ replaced by [Receiver](api/scala/index.html#org.apache.spark.streaming.receiver. the following advantages. * Methods like `stop` and `restart` have been added to for better control of the lifecycle of a receiver. See -the [custom receiver guide](streaming-custom-receiver.html) for more details. +the [custom receiver guide](streaming-custom-receivers.html) for more details. * Custom receivers can be implemented using both Scala and Java. To migrate your existing custom receivers from the earlier NetworkReceiver to the new Receiver, you have @@ -1356,6 +1547,7 @@ the `org.apache.spark.streaming.receivers` package were also moved to [`org.apache.spark.streaming.receiver`](api/scala/index.html#org.apache.spark.streaming.receiver.package) package and renamed for better clarity. +*************************************************************************************************** *************************************************************************************************** # Where to Go from Here @@ -1366,6 +1558,7 @@ package and renamed for better clarity. [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) * [KafkaUtils](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$), [FlumeUtils](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$), + [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$), [TwitterUtils](api/scala/index.html#org.apache.spark.streaming.twitter.TwitterUtils$), [ZeroMQUtils](api/scala/index.html#org.apache.spark.streaming.zeromq.ZeroMQUtils$), and [MQTTUtils](api/scala/index.html#org.apache.spark.streaming.mqtt.MQTTUtils$) @@ -1375,6 +1568,7 @@ package and renamed for better clarity. [PairJavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/PairJavaDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), + [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) [TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html), [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 3a8c816cfffa1..bfd07593b92ed 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -26,6 +26,7 @@ import pipes import random import shutil +import string import subprocess import sys import tempfile @@ -34,9 +35,11 @@ from optparse import OptionParser from sys import stderr import boto -from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType +from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType from boto import ec2 +DEFAULT_SPARK_VERSION = "1.0.0" + # A URL prefix from which to fetch AMI information AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" @@ -56,10 +59,10 @@ def parse_args(): help="Show this help message and exit") parser.add_option( "-s", "--slaves", type="int", default=1, - help="Number of slaves to launch (default: 1)") + help="Number of slaves to launch (default: %default)") parser.add_option( "-w", "--wait", type="int", default=120, - help="Seconds to wait for nodes to start (default: 120)") + help="Seconds to wait for nodes to start (default: %default)") parser.add_option( "-k", "--key-pair", help="Key pair to use on instances") @@ -68,7 +71,7 @@ def parse_args(): help="SSH private key file to use for logging into instances") parser.add_option( "-t", "--instance-type", default="m1.large", - help="Type of instance to launch (default: m1.large). " + + help="Type of instance to launch (default: %default). " + "WARNING: must be 64-bit; small instances won't work") parser.add_option( "-m", "--master-instance-type", default="", @@ -83,15 +86,15 @@ def parse_args(): "between zones applies)") parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") parser.add_option( - "-v", "--spark-version", default="1.0.0", - help="Version of Spark to use: 'X.Y.Z' or a specific git hash") + "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, + help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") parser.add_option( "--spark-git-repo", default="https://github.com/apache/spark", help="Github repo from which to checkout supplied commit hash") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: 1)") + help="Major version of Hadoop (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -102,26 +105,34 @@ def parse_args(): "(for debugging)") parser.add_option( "--ebs-vol-size", metavar="SIZE", type="int", default=0, - help="Attach a new EBS volume of size SIZE (in GB) to each node as " + - "/vol. The volumes will be deleted when the instances terminate. " + - "Only possible on EBS-backed AMIs.") + help="Size (in GB) of each EBS volume.") + parser.add_option( + "--ebs-vol-type", default="standard", + help="EBS volume type (e.g. 'gp2', 'standard').") + parser.add_option( + "--ebs-vol-num", type="int", default=1, + help="Number of EBS volumes to attach to each node as /vol[x]. " + + "The volumes will be deleted when the instances terminate. " + + "Only possible on EBS-backed AMIs. " + + "EBS volumes are only attached if --ebs-vol-size > 0." + + "Only support up to 8 EBS volumes.") parser.add_option( "--swap", metavar="SWAP", type="int", default=1024, - help="Swap space to set up per node, in MB (default: 1024)") + help="Swap space to set up per node, in MB (default: %default)") parser.add_option( "--spot-price", metavar="PRICE", type="float", help="If specified, launch slaves as spot instances with the given " + "maximum price (in dollars)") parser.add_option( "--ganglia", action="store_true", default=True, - help="Setup Ganglia monitoring on cluster (default: on). NOTE: " + + help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " + "the Ganglia page will be publicly accessible") parser.add_option( "--no-ganglia", action="store_false", dest="ganglia", help="Disable Ganglia monitoring for the cluster") parser.add_option( "-u", "--user", default="root", - help="The SSH user you want to connect as (default: root)") + help="The SSH user you want to connect as (default: %default)") parser.add_option( "--delete-groups", action="store_true", default=False, help="When destroying a cluster, delete the security groups that were created.") @@ -130,7 +141,7 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: 1)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + @@ -141,6 +152,12 @@ def parse_args(): parser.add_option( "--security-group-prefix", type="string", default=None, help="Use this prefix for the security group rather than the cluster name.") + parser.add_option( + "--authorized-address", type="string", default="0.0.0.0/0", + help="Address to authorize on created security groups (default: %default)") + parser.add_option( + "--additional-security-group", type="string", default="", + help="Additional security group to place the machines in") (opts, args) = parser.parse_args() if len(args) != 2: @@ -228,10 +245,10 @@ def get_spark_ami(opts): "cg1.4xlarge": "hvm", "hs1.8xlarge": "pvm", "hi1.4xlarge": "pvm", - "m3.medium": "pvm", - "m3.large": "pvm", - "m3.xlarge": "pvm", - "m3.2xlarge": "pvm", + "m3.medium": "hvm", + "m3.large": "hvm", + "m3.xlarge": "hvm", + "m3.2xlarge": "hvm", "cr1.8xlarge": "hvm", "i2.xlarge": "hvm", "i2.2xlarge": "hvm", @@ -293,28 +310,29 @@ def launch_cluster(conn, opts, cluster_name): else: master_group = get_or_make_group(conn, opts.security_group_prefix + "-master") slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves") + authorized_address = opts.authorized_address if master_group.rules == []: # Group was just now created master_group.authorize(src_group=master_group) master_group.authorize(src_group=slave_group) - master_group.authorize('tcp', 22, 22, '0.0.0.0/0') - master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0') - master_group.authorize('tcp', 18080, 18080, '0.0.0.0/0') - master_group.authorize('tcp', 19999, 19999, '0.0.0.0/0') - master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0') - master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0') - master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0') - master_group.authorize('tcp', 4040, 4045, '0.0.0.0/0') + master_group.authorize('tcp', 22, 22, authorized_address) + master_group.authorize('tcp', 8080, 8081, authorized_address) + master_group.authorize('tcp', 18080, 18080, authorized_address) + master_group.authorize('tcp', 19999, 19999, authorized_address) + master_group.authorize('tcp', 50030, 50030, authorized_address) + master_group.authorize('tcp', 50070, 50070, authorized_address) + master_group.authorize('tcp', 60070, 60070, authorized_address) + master_group.authorize('tcp', 4040, 4045, authorized_address) if opts.ganglia: - master_group.authorize('tcp', 5080, 5080, '0.0.0.0/0') + master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created slave_group.authorize(src_group=master_group) slave_group.authorize(src_group=slave_group) - slave_group.authorize('tcp', 22, 22, '0.0.0.0/0') - slave_group.authorize('tcp', 8080, 8081, '0.0.0.0/0') - slave_group.authorize('tcp', 50060, 50060, '0.0.0.0/0') - slave_group.authorize('tcp', 50075, 50075, '0.0.0.0/0') - slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0') - slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') + slave_group.authorize('tcp', 22, 22, authorized_address) + slave_group.authorize('tcp', 8080, 8081, authorized_address) + slave_group.authorize('tcp', 50060, 50060, authorized_address) + slave_group.authorize('tcp', 50075, 50075, authorized_address) + slave_group.authorize('tcp', 60060, 60060, authorized_address) + slave_group.authorize('tcp', 60075, 60075, authorized_address) # Check if instances are already running with the cluster name existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, @@ -326,6 +344,12 @@ def launch_cluster(conn, opts, cluster_name): # Figure out Spark AMI if opts.ami is None: opts.ami = get_spark_ami(opts) + + additional_groups = [] + if opts.additional_security_group: + additional_groups = [sg + for sg in conn.get_all_security_groups() + if opts.additional_security_group in (sg.name, sg.id)] print "Launching instances..." try: @@ -334,13 +358,25 @@ def launch_cluster(conn, opts, cluster_name): print >> stderr, "Could not find AMI " + opts.ami sys.exit(1) - # Create block device mapping so that we can add an EBS volume if asked to + # Create block device mapping so that we can add EBS volumes if asked to. + # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz block_map = BlockDeviceMapping() if opts.ebs_vol_size > 0: - device = EBSBlockDeviceType() - device.size = opts.ebs_vol_size - device.delete_on_termination = True - block_map["/dev/sdv"] = device + for i in range(opts.ebs_vol_num): + device = EBSBlockDeviceType() + device.size = opts.ebs_vol_size + device.volume_type = opts.ebs_vol_type + device.delete_on_termination = True + block_map["/dev/sd" + chr(ord('s') + i)] = device + + # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342). + if opts.instance_type.startswith('m3.'): + for i in range(get_num_disks(opts.instance_type)): + dev = BlockDeviceType() + dev.ephemeral_name = 'ephemeral%d' % i + # The first ephemeral drive is /dev/sdb. + name = '/dev/sd' + string.letters[i + 1] + block_map[name] = dev # Launch slaves if opts.spot_price is not None: @@ -360,7 +396,7 @@ def launch_cluster(conn, opts, cluster_name): placement=zone, count=num_slaves_this_zone, key_name=opts.key_pair, - security_groups=[slave_group], + security_groups=[slave_group] + additional_groups, instance_type=opts.instance_type, block_device_map=block_map, user_data=user_data_content) @@ -413,7 +449,7 @@ def launch_cluster(conn, opts, cluster_name): num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: slave_res = image.run(key_name=opts.key_pair, - security_groups=[slave_group], + security_groups=[slave_group] + additional_groups, instance_type=opts.instance_type, placement=zone, min_count=num_slaves_this_zone, @@ -439,48 +475,60 @@ def launch_cluster(conn, opts, cluster_name): if opts.zone == 'all': opts.zone = random.choice(conn.get_all_zones()).name master_res = image.run(key_name=opts.key_pair, - security_groups=[master_group], + security_groups=[master_group] + additional_groups, instance_type=master_type, placement=opts.zone, min_count=1, max_count=1, - block_device_map=block_map) + block_device_map=block_map, + user_data=user_data_content) master_nodes = master_res.instances print "Launched master in %s, regid = %s" % (zone, master_res.id) # Give the instances descriptive names - # TODO: Add retry logic for tagging with name since it's used to identify a cluster. for master in master_nodes: name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id) - for i in range(0, 5): - try: - master.add_tag(key='Name', value=name) - except: - print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if (i == 5): - raise "Error - failed max attempts to add name tag" - time.sleep(5) - + tag_instance(master, name) for slave in slave_nodes: name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id) - for i in range(0, 5): - try: - slave.add_tag(key='Name', value=name) - except: - print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if (i == 5): - raise "Error - failed max attempts to add name tag" - time.sleep(5) + tag_instance(slave, name) # Return all the instances return (master_nodes, slave_nodes) +def tag_instance(instance, name): + for i in range(0, 5): + try: + instance.add_tag(key='Name', value=name) + except: + print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) + if (i == 5): + raise "Error - failed max attempts to add name tag" + time.sleep(5) + # Get the EC2 instances in an existing cluster if available. # Returns a tuple of lists of EC2 instance objects for the masters and slaves + + def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): print "Searching for existing cluster " + cluster_name + "..." + # Search all the spot instance requests, and copy any tags from the spot + # instance request to the cluster. + spot_instance_requests = conn.get_all_spot_instance_requests() + for req in spot_instance_requests: + if req.state != u'active': + continue + name = req.tags.get(u'Name', "") + if name.startswith(cluster_name): + reservations = conn.get_all_instances(instance_ids=[req.instance_id]) + for res in reservations: + active = [i for i in res.instances if is_active(i)] + for instance in active: + if (instance.tags.get(u'Name') is None): + tag_instance(instance, name) + # Now proceed to detect master and slaves instances. reservations = conn.get_all_instances() master_nodes = [] slave_nodes = [] @@ -498,14 +546,16 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in with name " + cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in with name " + \ + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) - # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. + + def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): master = master_nodes[0].public_dns_name if deploy_ssh_key: @@ -798,6 +848,12 @@ def get_partition(total, num_partitions, current_partitions): def real_main(): (opts, action, cluster_name) = parse_args() + + # Input parameter validation + if opts.ebs_vol_num > 8: + print >> stderr, "ebs-vol-num cannot be greater than 8" + sys.exit(1) + try: conn = ec2.connect_to_region(opts.region) except Exception as e: @@ -843,7 +899,8 @@ def real_main(): if opts.security_group_prefix is None: group_names = [cluster_name + "-master", cluster_name + "-slaves"] else: - group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] + group_names = [opts.security_group_prefix + "-master", + opts.security_group_prefix + "-slaves"] attempt = 1 while attempt <= 3: diff --git a/examples/pom.xml b/examples/pom.xml index 8c4c128bb484d..3f46c40464d3b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml @@ -46,8 +46,14 @@
    - + + + + com.google.guava + guava + compile + org.apache.spark spark-core_${scala.binary.version} @@ -209,6 +215,12 @@ + + com.google.guava:guava + + com/google/common/base/Optional* + + *:* @@ -226,6 +238,18 @@ shade + + + com.google + org.spark-project.guava + + com.google.common.** + + + com.google.common.base.Optional** + + + diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index 7ea6df9c17245..c22506491fbff 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -96,7 +96,7 @@ public Double call(Iterable rs) { .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { @Override public Iterable> call(Tuple2, Double> s) { - int urlCount = Iterables.size(s._1); + int urlCount = Iterables.size(s._1); List> results = new ArrayList>(); for (String n : s._1) { results.add(new Tuple2(n, s._2() / urlCount)); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java new file mode 100644 index 0000000000000..1f82e3f4cb18e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.HashMap; + +import scala.Tuple2; + +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; + +/** + * Classification and regression using decision trees. + */ +public final class JavaDecisionTree { + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + if (args.length == 1) { + datapath = args[0]; + } else if (args.length > 1) { + System.err.println("Usage: JavaDecisionTree "); + System.exit(1); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + HashMap categoricalFeaturesInfo = new HashMap(); + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model for classification. + final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + + // Train a DecisionTree model for regression. + impurity = "variance"; + final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on training instances and compute training error + JavaPairRDD regressorPredictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(regressionModel.predict(p.features()), p.label()); + } + }); + Double trainMSE = + regressorPredictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + regressionModel); + + sc.stop(); + } +} diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index e902ae29753c0..cfda8d8327aa3 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -23,7 +23,8 @@ Read data file users.avro in local Spark distro: $ cd $SPARK_HOME -$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +$ ./bin/spark-submit --driver-class-path /path/to/example/jar \ +> ./examples/src/main/python/avro_inputformat.py \ > examples/src/main/resources/users.avro {u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} {u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} @@ -40,7 +41,8 @@ ] } -$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +$ ./bin/spark-submit --driver-class-path /path/to/example/jar \ +> ./examples/src/main/python/avro_inputformat.py \ > examples/src/main/resources/users.avro examples/src/main/resources/user.avsc {u'favorite_color': None, u'name': u'Alyssa'} {u'favorite_color': u'red', u'name': u'Ben'} @@ -51,8 +53,10 @@ Usage: avro_inputformat [reader_schema_file] Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/avro_inputformat.py [reader_schema_file] - Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/avro_inputformat.py [reader_schema_file] + Assumes you have Avro data stored in . Reader schema can be optionally specified + in [reader_schema_file]. """ exit(-1) @@ -62,9 +66,10 @@ conf = None if len(sys.argv) == 3: schema_rdd = sc.textFile(sys.argv[2], 1).collect() - conf = {"avro.schema.input.key" : reduce(lambda x, y: x+y, schema_rdd)} + conf = {"avro.schema.input.key": reduce(lambda x, y: x + y, schema_rdd)} - avro_rdd = sc.newAPIHadoopFile(path, + avro_rdd = sc.newAPIHadoopFile( + path, "org.apache.avro.mapreduce.AvroKeyInputFormat", "org.apache.avro.mapred.AvroKey", "org.apache.hadoop.io.NullWritable", diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index e4a897f61e39d..05f34b74df45a 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -51,7 +51,8 @@ Usage: cassandra_inputformat Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_inputformat.py + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/cassandra_inputformat.py Assumes you have some data in Cassandra already, running on , in and """ exit(-1) @@ -61,12 +62,12 @@ cf = sys.argv[3] sc = SparkContext(appName="CassandraInputFormat") - conf = {"cassandra.input.thrift.address":host, - "cassandra.input.thrift.port":"9160", - "cassandra.input.keyspace":keyspace, - "cassandra.input.columnfamily":cf, - "cassandra.input.partitioner.class":"Murmur3Partitioner", - "cassandra.input.page.row.size":"3"} + conf = {"cassandra.input.thrift.address": host, + "cassandra.input.thrift.port": "9160", + "cassandra.input.keyspace": keyspace, + "cassandra.input.columnfamily": cf, + "cassandra.input.partitioner.class": "Murmur3Partitioner", + "cassandra.input.page.row.size": "3"} cass_rdd = sc.newAPIHadoopRDD( "org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat", "java.util.Map", diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index 836c35b5c6794..d144539e58b8f 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -50,7 +50,8 @@ Usage: cassandra_outputformat Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_outputformat.py + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/cassandra_outputformat.py Assumes you have created the following table in Cassandra already, running on , in . @@ -67,16 +68,16 @@ cf = sys.argv[3] sc = SparkContext(appName="CassandraOutputFormat") - conf = {"cassandra.output.thrift.address":host, - "cassandra.output.thrift.port":"9160", - "cassandra.output.keyspace":keyspace, - "cassandra.output.partitioner.class":"Murmur3Partitioner", - "cassandra.output.cql":"UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?", - "mapreduce.output.basename":cf, - "mapreduce.outputformat.class":"org.apache.cassandra.hadoop.cql3.CqlOutputFormat", - "mapreduce.job.output.key.class":"java.util.Map", - "mapreduce.job.output.value.class":"java.util.List"} - key = {"user_id" : int(sys.argv[4])} + conf = {"cassandra.output.thrift.address": host, + "cassandra.output.thrift.port": "9160", + "cassandra.output.keyspace": keyspace, + "cassandra.output.partitioner.class": "Murmur3Partitioner", + "cassandra.output.cql": "UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?", + "mapreduce.output.basename": cf, + "mapreduce.outputformat.class": "org.apache.cassandra.hadoop.cql3.CqlOutputFormat", + "mapreduce.job.output.key.class": "java.util.Map", + "mapreduce.job.output.value.class": "java.util.List"} + key = {"user_id": int(sys.argv[4])} sc.parallelize([(key, sys.argv[5:])]).saveAsNewAPIHadoopDataset( conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index befacee0dea56..3b16010f1cb97 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -51,7 +51,8 @@ Usage: hbase_inputformat Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_inputformat.py
    + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/hbase_inputformat.py
    Assumes you have some data in HBase already, running on , in
    """ exit(-1) @@ -61,12 +62,15 @@ sc = SparkContext(appName="HBaseInputFormat") conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter" + valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter" + hbase_rdd = sc.newAPIHadoopRDD( "org.apache.hadoop.hbase.mapreduce.TableInputFormat", "org.apache.hadoop.hbase.io.ImmutableBytesWritable", "org.apache.hadoop.hbase.client.Result", - keyConverter="org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter", - valueConverter="org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter", + keyConverter=keyConv, + valueConverter=valueConv, conf=conf) output = hbase_rdd.collect() for (k, v) in output: diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index 49bbc5aebdb0b..abb425b1f886a 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -44,8 +44,10 @@ Usage: hbase_outputformat
    Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_outputformat.py - Assumes you have created
    with column family in HBase running on already + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/hbase_outputformat.py + Assumes you have created
    with column family in HBase + running on already """ exit(-1) @@ -55,13 +57,15 @@ conf = {"hbase.zookeeper.quorum": host, "hbase.mapred.outputtable": table, - "mapreduce.outputformat.class" : "org.apache.hadoop.hbase.mapreduce.TableOutputFormat", - "mapreduce.job.output.key.class" : "org.apache.hadoop.hbase.io.ImmutableBytesWritable", - "mapreduce.job.output.value.class" : "org.apache.hadoop.io.Writable"} + "mapreduce.outputformat.class": "org.apache.hadoop.hbase.mapreduce.TableOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.hbase.io.ImmutableBytesWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.Writable"} + keyConv = "org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter" + valueConv = "org.apache.spark.examples.pythonconverters.StringListToPutConverter" sc.parallelize([sys.argv[3:]]).map(lambda x: (x[0], x)).saveAsNewAPIHadoopDataset( conf=conf, - keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", - valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") + keyConverter=keyConv, + valueConverter=valueConv) sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 6b16a56e44af7..4218eca822a99 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -28,7 +28,7 @@ if __name__ == "__main__": - if len(sys.argv) not in [1,2]: + if len(sys.argv) not in [1, 2]: print >> sys.stderr, "Usage: correlations ()" exit(-1) sc = SparkContext(appName="PythonCorrelations") diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index 6e4a4a0cb6be0..61ea4e06ecf3a 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -21,7 +21,9 @@ This example requires NumPy (http://www.numpy.org/). """ -import numpy, os, sys +import numpy +import os +import sys from operator import add @@ -127,7 +129,7 @@ def usage(): (reindexedData, origToNewLabels) = reindexClassLabels(points) # Train a classifier. - categoricalFeaturesInfo={} # no categorical features + categoricalFeaturesInfo = {} # no categorical features model = DecisionTree.trainClassifier(reindexedData, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index b388d8d83fb86..1e8892741e714 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -32,8 +32,8 @@ sc = SparkContext(appName="PythonRandomRDDGeneration") - numExamples = 10000 # number of examples to generate - fraction = 0.1 # fraction of data to sample + numExamples = 10000 # number of examples to generate + fraction = 0.1 # fraction of data to sample # Example: RandomRDDs.normalRDD normalRDD = RandomRDDs.normalRDD(sc, numExamples) @@ -45,7 +45,7 @@ print # Example: RandomRDDs.normalVectorRDD - normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2) print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() print ' First 5 samples:' for sample in normalVectorRDD.take(5): diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index ec64a5978c672..92af3af5ebd1e 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -36,7 +36,7 @@ sc = SparkContext(appName="PythonSampledRDDs") - fraction = 0.1 # fraction of data to sample + fraction = 0.1 # fraction of data to sample examples = MLUtils.loadLibSVMFile(sc, datapath) numExamples = examples.count() @@ -49,9 +49,9 @@ expectedSampleSize = int(numExamples * fraction) print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ % (fraction, expectedSampleSize) - sampledRDD = examples.sample(withReplacement = True, fraction = fraction) + sampledRDD = examples.sample(withReplacement=True, fraction=fraction) print ' RDD.sample(): sample has %d examples' % sampledRDD.count() - sampledArray = examples.takeSample(withReplacement = True, num = expectedSampleSize) + sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize) print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) print @@ -66,7 +66,7 @@ fractions = {} for k in keyCountsA.keys(): fractions[k] = fraction - sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = True, fractions = fractions) + sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions) keyCountsB = sampledByKeyRDD.countByKey() sizeB = sum(keyCountsB.values()) print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index fc37459dc74aa..ee9036adfa281 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -35,7 +35,7 @@ def f(_): y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + count = sc.parallelize(xrange(1, n + 1), slices).map(f).reduce(add) print "Pi is roughly %f" % (4.0 * count / n) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index d583cf421ed23..3258510894372 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -22,9 +22,9 @@ import java.util.Random import scala.math.exp import breeze.linalg.{Vector, DenseVector} +import org.apache.hadoop.conf.Configuration import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.InputFormatInfo @@ -70,7 +70,7 @@ object SparkHdfsLR { val sparkConf = new SparkConf().setAppName("SparkHdfsLR") val inputPath = args(0) - val conf = SparkHadoopUtil.get.newConfiguration() + val conf = new Configuration() val sc = new SparkContext(sparkConf, InputFormatInfo.computePreferredLocations( Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 22127621867e1..96d13612e46dd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -22,9 +22,9 @@ import java.util.Random import scala.math.exp import breeze.linalg.{Vector, DenseVector} +import org.apache.hadoop.conf.Configuration import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.InputFormatInfo import org.apache.spark.storage.StorageLevel @@ -52,8 +52,8 @@ object SparkTachyonHdfsLR { def main(args: Array[String]) { val inputPath = args(0) - val conf = SparkHadoopUtil.get.newConfiguration() val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") + val conf = new Configuration() val sc = new SparkContext(sparkConf, InputFormatInfo.computePreferredLocations( Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala similarity index 98% rename from graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala rename to examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index c1513a00453cf..c4317a6aec798 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,12 +15,13 @@ * limitations under the License. */ -package org.apache.spark.graphx.lib +package org.apache.spark.examples.graphx import scala.collection.mutable import org.apache.spark._ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ +import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.PartitionStrategy._ /** diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index 6ef3b62dcbedc..e809a65b79975 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ import org.apache.spark._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib.Analytics + /** * Uses GraphX to run PageRank on a LiveJournal social network graph. Download the dataset from diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 551c339b19523..5f35a5836462e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -38,12 +38,13 @@ object SynthBenchmark { * Options: * -app "pagerank" or "cc" for pagerank or connected components. (Default: pagerank) * -niters the number of iterations of pagerank to use (Default: 10) - * -numVertices the number of vertices in the graph (Default: 1000000) + * -nverts the number of vertices in the graph (Default: 1000000) * -numEPart the number of edge partitions in the graph (Default: number of cores) * -partStrategy the graph partitioning strategy to use * -mu the mean parameter for the log-normal graph (Default: 4.0) * -sigma the stdev parameter for the log-normal graph (Default: 1.3) * -degFile the local file to save the degree information (Default: Empty) + * -seed seed to use for RNGs (Default: -1, picks seed randomly) */ def main(args: Array[String]) { val options = args.map { @@ -62,6 +63,7 @@ object SynthBenchmark { var mu: Double = 4.0 var sigma: Double = 1.3 var degFile: String = "" + var seed: Int = -1 options.foreach { case ("app", v) => app = v @@ -72,6 +74,7 @@ object SynthBenchmark { case ("mu", v) => mu = v.toDouble case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v + case ("seed", v) => seed = v.toInt case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) } @@ -85,7 +88,7 @@ object SynthBenchmark { // Create the graph println(s"Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, - numEPart.getOrElse(sc.defaultParallelism), mu, sigma) + numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)).cache() @@ -113,7 +116,7 @@ object SynthBenchmark { println(s"Total PageRank = $totalPR") } else if (app == "cc") { println("Running Connected Components") - val numComponents = graph.connectedComponents.vertices.map(_._2).distinct() + val numComponents = graph.connectedComponents.vertices.map(_._2).distinct().count() println(s"Number of components = $numComponents") } val runTime = System.currentTimeMillis() - startTime diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index cf3d2cca81ff6..72c3ab475b61f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -52,9 +52,9 @@ object DecisionTreeRunner { input: String = null, dataFormat: String = "libsvm", algo: Algo = Classification, - maxDepth: Int = 4, + maxDepth: Int = 5, impurity: ImpurityType = Gini, - maxBins: Int = 100, + maxBins: Int = 32, fracTest: Double = 0.2) def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index f96bc1bf00b92..89dfa26c2299c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.linalg.Vectors /** * An example k-means app. Run with * {{{ - * ./bin/spark-example org.apache.spark.examples.mllib.DenseKMeans [options] + * ./bin/run-example org.apache.spark.examples.mllib.DenseKMeans [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 88acd9dbb0878..952fa2a5109a4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLUtils /** * An example naive Bayes app. Run with * {{{ - * ./bin/spark-example org.apache.spark.examples.mllib.SparseNaiveBayes [options] + * ./bin/run-example org.apache.spark.examples.mllib.SparseNaiveBayes [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index daa1ced63c701..a4d159bf38377 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -44,7 +44,7 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.foldLeft(0)(_ + _) + val currentCount = values.sum val previousCount = state.getOrElse(0) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0c68defa5e101..ac291bd4fde20 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,23 +21,24 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml + org.apache.spark spark-streaming-flume-sink_2.10 streaming-flume-sink - jar Spark Project External Flume Sink http://spark.apache.org/ + org.apache.flume flume-ng-sdk - 1.4.0 + ${flume.version} io.netty @@ -52,7 +53,7 @@ org.apache.flume flume-ng-core - 1.4.0 + ${flume.version} io.netty @@ -64,20 +65,26 @@ - - org.scala-lang - scala-library - org.scalatest scalatest_${scala.binary.version} + test + + + org.scala-lang + scala-library - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + + io.netty + netty + 3.4.0.Final + test @@ -91,7 +98,7 @@ org.apache.avro avro-maven-plugin - 1.7.3 + ${avro.version} ${project.basedir}/target/scala-${scala.binary.version}/src_managed/main/compiled_avro diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 7da8eb3e35912..e77cf7bfa54d0 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -19,6 +19,8 @@ package org.apache.spark.streaming.flume.sink import java.util.concurrent.{ConcurrentHashMap, Executors} import java.util.concurrent.atomic.AtomicLong +import scala.collection.JavaConversions._ + import org.apache.flume.Channel import org.apache.commons.lang.RandomStringUtils import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -45,7 +47,8 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, new ThreadFactoryBuilder().setDaemon(true) .setNameFormat("Spark Sink Processor Thread - %d").build())) - private val processorMap = new ConcurrentHashMap[CharSequence, TransactionProcessor]() + private val sequenceNumberToProcessor = + new ConcurrentHashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. // So it is possible to commit a transaction which may have been meant for the sink before the // restart. @@ -55,6 +58,8 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha private val seqBase = RandomStringUtils.randomAlphanumeric(8) private val seqCounter = new AtomicLong(0) + @volatile private var stopped = false + /** * Returns a bunch of events to Spark over Avro RPC. * @param n Maximum number of events to return in a batch @@ -63,18 +68,33 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha override def getEventBatch(n: Int): EventBatch = { logDebug("Got getEventBatch call from Spark.") val sequenceNumber = seqBase + seqCounter.incrementAndGet() - val processor = new TransactionProcessor(channel, sequenceNumber, - n, transactionTimeout, backOffInterval, this) - transactionExecutorOpt.foreach(executor => { - executor.submit(processor) - }) - // Wait until a batch is available - will be an error if error message is non-empty - val batch = processor.getEventBatch - if (!SparkSinkUtils.isErrorBatch(batch)) { - processorMap.put(sequenceNumber.toString, processor) - logDebug("Sending event batch with sequence number: " + sequenceNumber) + createProcessor(sequenceNumber, n) match { + case Some(processor) => + transactionExecutorOpt.foreach(_.submit(processor)) + // Wait until a batch is available - will be an error if error message is non-empty + val batch = processor.getEventBatch + if (SparkSinkUtils.isErrorBatch(batch)) { + // Remove the processor if it is an error batch since no ACK is sent. + removeAndGetProcessor(sequenceNumber) + logWarning("Received an error batch - no events were received from channel! ") + } + batch + case None => + new EventBatch("Spark sink has been stopped!", "", java.util.Collections.emptyList()) + } + } + + private def createProcessor(seq: String, n: Int): Option[TransactionProcessor] = { + sequenceNumberToProcessor.synchronized { + if (!stopped) { + val processor = new TransactionProcessor( + channel, seq, n, transactionTimeout, backOffInterval, this) + sequenceNumberToProcessor.put(seq, processor) + Some(processor) + } else { + None + } } - batch } /** @@ -116,7 +136,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha * longer tracked and the caller is responsible for that txn processor. */ private[sink] def removeAndGetProcessor(sequenceNumber: CharSequence): TransactionProcessor = { - processorMap.remove(sequenceNumber.toString) // The toString is required! + sequenceNumberToProcessor.synchronized { + sequenceNumberToProcessor.remove(sequenceNumber.toString) + } } /** @@ -124,8 +146,10 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha */ def shutdown() { logInfo("Shutting down Spark Avro Callback Handler") - transactionExecutorOpt.foreach(executor => { - executor.shutdownNow() - }) + sequenceNumberToProcessor.synchronized { + stopped = true + sequenceNumberToProcessor.values().foreach(_.shutdown()) + } + transactionExecutorOpt.foreach(_.shutdownNow()) } } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index b9e3c786ebb3b..13f3aa94be414 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -60,6 +60,8 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // succeeded. @volatile private var batchSuccess = false + @volatile private var stopped = false + // The transaction that this processor would handle var txOpt: Option[Transaction] = None @@ -88,6 +90,11 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, batchAckLatch.countDown() } + private[flume] def shutdown(): Unit = { + logDebug("Shutting down transaction processor") + stopped = true + } + /** * Populates events into the event batch. If the batch cannot be populated, * this method will not set the events into the event batch, but it sets an error message. @@ -106,7 +113,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, var gotEventsInThisTxn = false var loopCounter: Int = 0 loop.breakable { - while (events.size() < maxBatchSize + while (!stopped && events.size() < maxBatchSize && loopCounter < totalAttemptsToRemoveFromChannel) { loopCounter += 1 Option(channel.take()) match { @@ -115,7 +122,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, ByteBuffer.wrap(event.getBody))) gotEventsInThisTxn = true case None => - if (!gotEventsInThisTxn) { + if (!gotEventsInThisTxn && !stopped) { logDebug("Sleeping for " + backOffInterval + " millis as no events were read in" + " the current transaction") TimeUnit.MILLISECONDS.sleep(backOffInterval) @@ -125,7 +132,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, } } } - if (!gotEventsInThisTxn) { + if (!gotEventsInThisTxn && !stopped) { val msg = "Tried several times, " + "but did not get any events from the channel!" logWarning(msg) @@ -136,6 +143,11 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, } }) } catch { + case interrupted: InterruptedException => + // Don't pollute logs if the InterruptedException came from this being stopped + if (!stopped) { + logWarning("Error while processing transaction.", interrupted) + } case e: Exception => logWarning("Error while processing transaction.", e) eventBatch.setErrorMsg(e.getMessage) diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 44b27edf85ce8..75a6668c6210b 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -30,14 +30,14 @@ import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder -import org.apache.spark.streaming.TestSuiteBase import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.scalatest.FunSuite -class SparkSinkSuite extends TestSuiteBase { +class SparkSinkSuite extends FunSuite { val eventsPerBatch = 1000 val channelCapacity = 5000 - test("Success") { + test("Success with ack") { val (channel, sink) = initializeChannelAndSink() channel.start() sink.start() @@ -57,7 +57,7 @@ class SparkSinkSuite extends TestSuiteBase { transceiver.close() } - test("Nack") { + test("Failure with nack") { val (channel, sink) = initializeChannelAndSink() channel.start() sink.start() @@ -76,7 +76,7 @@ class SparkSinkSuite extends TestSuiteBase { transceiver.close() } - test("Timeout") { + test("Failure with timeout") { val (channel, sink) = initializeChannelAndSink(Map(SparkSinkConfig .CONF_TRANSACTION_TIMEOUT -> 1.toString)) channel.start() diff --git a/external/flume/pom.xml b/external/flume/pom.xml index c532705f3950c..7d31e32283d88 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml @@ -40,6 +40,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-streaming-flume-sink_${scala.binary.version} + ${project.version} + org.apache.spark spark-streaming_${scala.binary.version} @@ -50,7 +55,7 @@ org.apache.flume flume-ng-sdk - 1.4.0 + ${flume.version} io.netty @@ -82,11 +87,6 @@ junit-interface test - - org.apache.spark - spark-streaming-flume-sink_2.10 - ${project.version} - target/scala-${scala.binary.version}/classes diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala new file mode 100644 index 0000000000000..88cc2aa3bf022 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Throwables + +import org.apache.spark.Logging +import org.apache.spark.streaming.flume.sink._ + +/** + * This class implements the core functionality of [[FlumePollingReceiver]]. When started it + * pulls data from Flume, stores it to Spark and then sends an Ack or Nack. This class should be + * run via an [[java.util.concurrent.Executor]] as this implements [[Runnable]] + * + * @param receiver The receiver that owns this instance. + */ + +private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends Runnable with + Logging { + + def run(): Unit = { + while (!receiver.isStopped()) { + val connection = receiver.getConnections.poll() + val client = connection.client + var batchReceived = false + var seq: CharSequence = null + try { + getBatch(client) match { + case Some(eventBatch) => + batchReceived = true + seq = eventBatch.getSequenceNumber + val events = toSparkFlumeEvents(eventBatch.getEvents) + if (store(events)) { + sendAck(client, seq) + } else { + sendNack(batchReceived, client, seq) + } + case None => + } + } catch { + case e: Exception => + Throwables.getRootCause(e) match { + // If the cause was an InterruptedException, then check if the receiver is stopped - + // if yes, just break out of the loop. Else send a Nack and log a warning. + // In the unlikely case, the cause was not an Exception, + // then just throw it out and exit. + case interrupted: InterruptedException => + if (!receiver.isStopped()) { + logWarning("Interrupted while receiving data from Flume", interrupted) + sendNack(batchReceived, client, seq) + } + case exception: Exception => + logWarning("Error while receiving data from Flume", exception) + sendNack(batchReceived, client, seq) + } + } finally { + receiver.getConnections.add(connection) + } + } + } + + /** + * Gets a batch of events from the specified client. This method does not handle any exceptions + * which will be propogated to the caller. + * @param client Client to get events from + * @return [[Some]] which contains the event batch if Flume sent any events back, else [[None]] + */ + private def getBatch(client: SparkFlumeProtocol.Callback): Option[EventBatch] = { + val eventBatch = client.getEventBatch(receiver.getMaxBatchSize) + if (!SparkSinkUtils.isErrorBatch(eventBatch)) { + // No error, proceed with processing data + logDebug(s"Received batch of ${eventBatch.getEvents.size} events with sequence " + + s"number: ${eventBatch.getSequenceNumber}") + Some(eventBatch) + } else { + logWarning("Did not receive events from Flume agent due to error on the Flume agent: " + + eventBatch.getErrorMsg) + None + } + } + + /** + * Store the events in the buffer to Spark. This method will not propogate any exceptions, + * but will propogate any other errors. + * @param buffer The buffer to store + * @return true if the data was stored without any exception being thrown, else false + */ + private def store(buffer: ArrayBuffer[SparkFlumeEvent]): Boolean = { + try { + receiver.store(buffer) + true + } catch { + case e: Exception => + logWarning("Error while attempting to store data received from Flume", e) + false + } + } + + /** + * Send an ack to the client for the sequence number. This method does not handle any exceptions + * which will be propagated to the caller. + * @param client client to send the ack to + * @param seq sequence number of the batch to be ack-ed. + * @return + */ + private def sendAck(client: SparkFlumeProtocol.Callback, seq: CharSequence): Unit = { + logDebug("Sending ack for sequence number: " + seq) + client.ack(seq) + logDebug("Ack sent for sequence number: " + seq) + } + + /** + * This method sends a Nack if a batch was received to the client with the given sequence + * number. Any exceptions thrown by the RPC call is simply thrown out as is - no effort is made + * to handle it. + * @param batchReceived true if a batch was received. If this is false, no nack is sent + * @param client The client to which the nack should be sent + * @param seq The sequence number of the batch that is being nack-ed. + */ + private def sendNack(batchReceived: Boolean, client: SparkFlumeProtocol.Callback, + seq: CharSequence): Unit = { + if (batchReceived) { + // Let Flume know that the events need to be pushed back into the channel. + logDebug("Sending nack for sequence number: " + seq) + client.nack(seq) // If the agent is down, even this could fail and throw + logDebug("Nack sent for sequence number: " + seq) + } + } + + /** + * Utility method to convert [[SparkSinkEvent]]s to [[SparkFlumeEvent]]s + * @param events - Events to convert to SparkFlumeEvents + * @return - The SparkFlumeEvent generated from SparkSinkEvent + */ + private def toSparkFlumeEvents(events: java.util.List[SparkSinkEvent]): + ArrayBuffer[SparkFlumeEvent] = { + // Convert each Flume event to a serializable SparkFlumeEvent + val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) + var j = 0 + while (j < events.size()) { + val event = events(j) + val sparkFlumeEvent = new SparkFlumeEvent() + sparkFlumeEvent.event.setBody(event.getBody) + sparkFlumeEvent.event.setHeaders(event.getHeaders) + buffer += sparkFlumeEvent + j += 1 + } + buffer + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 148262bb6771e..92fa5b41be89e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -18,10 +18,9 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent.{LinkedBlockingQueue, TimeUnit, Executors} +import java.util.concurrent.{LinkedBlockingQueue, Executors} import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -86,61 +85,9 @@ private[streaming] class FlumePollingReceiver( connections.add(new FlumeConnection(transceiver, client)) }) for (i <- 0 until parallelism) { - logInfo("Starting Flume Polling Receiver worker threads starting..") + logInfo("Starting Flume Polling Receiver worker threads..") // Threads that pull data from Flume. - receiverExecutor.submit(new Runnable { - override def run(): Unit = { - while (true) { - val connection = connections.poll() - val client = connection.client - try { - val eventBatch = client.getEventBatch(maxBatchSize) - if (!SparkSinkUtils.isErrorBatch(eventBatch)) { - // No error, proceed with processing data - val seq = eventBatch.getSequenceNumber - val events: java.util.List[SparkSinkEvent] = eventBatch.getEvents - logDebug( - "Received batch of " + events.size() + " events with sequence number: " + seq) - try { - // Convert each Flume event to a serializable SparkFlumeEvent - val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) - var j = 0 - while (j < events.size()) { - buffer += toSparkFlumeEvent(events(j)) - j += 1 - } - store(buffer) - logDebug("Sending ack for sequence number: " + seq) - // Send an ack to Flume so that Flume discards the events from its channels. - client.ack(seq) - logDebug("Ack sent for sequence number: " + seq) - } catch { - case e: Exception => - try { - // Let Flume know that the events need to be pushed back into the channel. - logDebug("Sending nack for sequence number: " + seq) - client.nack(seq) // If the agent is down, even this could fail and throw - logDebug("Nack sent for sequence number: " + seq) - } catch { - case e: Exception => logError( - "Sending Nack also failed. A Flume agent is down.") - } - TimeUnit.SECONDS.sleep(2L) // for now just leave this as a fixed 2 seconds. - logWarning("Error while attempting to store events", e) - } - } else { - logWarning("Did not receive events from Flume agent due to error on the Flume " + - "agent: " + eventBatch.getErrorMsg) - } - } catch { - case e: Exception => - logWarning("Error while reading data from Flume", e) - } finally { - connections.add(connection) - } - } - } - }) + receiverExecutor.submit(new FlumeBatchFetcher(this)) } } @@ -153,16 +100,12 @@ private[streaming] class FlumePollingReceiver( channelFactory.releaseExternalResources() } - /** - * Utility method to convert [[SparkSinkEvent]] to [[SparkFlumeEvent]] - * @param event - Event to convert to SparkFlumeEvent - * @return - The SparkFlumeEvent generated from SparkSinkEvent - */ - private def toSparkFlumeEvent(event: SparkSinkEvent): SparkFlumeEvent = { - val sparkFlumeEvent = new SparkFlumeEvent() - sparkFlumeEvent.event.setBody(event.getBody) - sparkFlumeEvent.event.setHeaders(event.getHeaders) - sparkFlumeEvent + private[flume] def getConnections: LinkedBlockingQueue[FlumeConnection] = { + this.connections + } + + private[flume] def getMaxBatchSize: Int = { + this.maxBatchSize } } @@ -171,7 +114,7 @@ private[streaming] class FlumePollingReceiver( * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ -private class FlumeConnection(val transceiver: NettyTransceiver, +private[flume] class FlumeConnection(val transceiver: NettyTransceiver, val client: SparkFlumeProtocol.Callback) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 73dffef953309..6ee7ac974b4a0 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -109,11 +109,11 @@ class FlumeStreamSuite extends TestSuiteBase { } class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { - override def newChannel(pipeline:ChannelPipeline) : SocketChannel = { - var encoder : ZlibEncoder = new ZlibEncoder(compressionLevel); - pipeline.addFirst("deflater", encoder); - pipeline.addFirst("inflater", new ZlibDecoder()); - super.newChannel(pipeline); + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) } } } diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4e2275ab238f7..2067c473f0e3f 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index dc48a08c93de2..371f1f1e9d39a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index b93ad016f84f0..1d7dd49d15c22 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 22c1fff23d9a2..7e48968feb3bc 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 5308bb4e440ea..8658ecf5abfab 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index a54b34235dfb4..560244ad93369 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index a8b907b241893..aa917d0575c4c 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -75,7 +75,7 @@ * onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. */ -public final class JavaKinesisWordCountASL { +public final class JavaKinesisWordCountASL { // needs to be public for access from run-example private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); @@ -87,10 +87,10 @@ public static void main(String[] args) { /* Check that all required args were passed in. */ if (args.length < 2) { System.err.println( - "|Usage: KinesisWordCount \n" + - "| is the name of the Kinesis stream\n" + - "| is the endpoint of the Kinesis service\n" + - "| (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); + "Usage: JavaKinesisWordCountASL \n" + + " is the name of the Kinesis stream\n" + + " is the endpoint of the Kinesis service\n" + + " (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); System.exit(1); } @@ -130,10 +130,10 @@ public static void main(String[] args) { /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ List> streamsList = new ArrayList>(numStreams); for (int i = 0; i < numStreams; i++) { - streamsList.add( - KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) - ); + streamsList.add( + KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) + ); } /* Union all the streams if there is more than 1 stream */ diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index d03edf8b30a9f..fffd90de08240 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -69,7 +69,7 @@ import org.apache.log4j.Level * dummy data onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. */ -object KinesisWordCountASL extends Logging { +private object KinesisWordCountASL extends Logging { def main(args: Array[String]) { /* Check that all required args were passed in. */ if (args.length < 2) { @@ -154,7 +154,7 @@ object KinesisWordCountASL extends Logging { * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ * https://kinesis.us-east-1.amazonaws.com 10 5 */ -object KinesisWordCountProducerASL { +private object KinesisWordCountProducerASL { def main(args: Array[String]) { if (args.length < 4) { System.err.println("Usage: KinesisWordCountProducerASL " + @@ -235,7 +235,7 @@ object KinesisWordCountProducerASL { * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ -object StreamingExamples extends Logging { +private[streaming] object StreamingExamples extends Logging { /** Set reasonable logging levels for streaming if the user has not configured log4j. */ def setStreamingLogLevels() { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 713cac0e293c0..96f4399accd3a 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -35,7 +35,7 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionIn object KinesisUtils { /** * Create an InputDStream that pulls messages from a Kinesis stream. - * + * :: Experimental :: * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -52,6 +52,7 @@ object KinesisUtils { * * @return ReceiverInputDStream[Array[Byte]] */ + @Experimental def createStream( ssc: StreamingContext, streamName: String, @@ -65,9 +66,8 @@ object KinesisUtils { /** * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. - * + * :: Experimental :: * @param jssc Java StreamingContext object - * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. @@ -83,6 +83,7 @@ object KinesisUtils { * * @return JavaReceiverInputDStream[Array[Byte]] */ + @Experimental def createStream( jssc: JavaStreamingContext, streamName: String, diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index a5b162a0482e4..71a078d58a8d8 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 6dd52fc618b1e..3f49b1d63b6e1 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 899a3cbd62b60..5bcb96b136ed7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -37,7 +37,15 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { - partitionsRDD.setName("EdgeRDD") + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("EdgeRDD") override protected def getPartitions: Array[Partition] = partitionsRDD.partitions diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 5e7e72a764cc8..13033fee0e6b5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -91,7 +91,7 @@ object PartitionStrategy { case object EdgePartition1D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val mixingPrime: VertexId = 1125899906842597L - (math.abs(src) * mixingPrime).toInt % numParts + (math.abs(src * mixingPrime) % numParts).toInt } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 4825d12fc27b3..04fbc9dbab8d1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -108,7 +108,7 @@ class VertexRDD[@specialized VD: ClassTag]( /** The number of vertices in the RDD. */ override def count(): Long = { - partitionsRDD.map(_.size).reduce(_ + _) + partitionsRDD.map(_.size.toLong).reduce(_ + _) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 60149548ab852..b8309289fe475 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -40,7 +40,7 @@ object GraphGenerators { val RMATd = 0.25 /** - * Generate a graph whose vertex out degree is log normal. + * Generate a graph whose vertex out degree distribution is log normal. * * The default values for mu and sigma are taken from the Pregel paper: * @@ -48,33 +48,36 @@ object GraphGenerators { * Ilan Horn, Naty Leiser, and Grzegorz Czajkowski. 2010. * Pregel: a system for large-scale graph processing. SIGMOD '10. * - * @param sc - * @param numVertices - * @param mu - * @param sigma - * @return + * If the seed is -1 (default), a random seed is chosen. Otherwise, use + * the user-specified seed. + * + * @param sc Spark Context + * @param numVertices number of vertices in generated graph + * @param numEParts (optional) number of partitions + * @param mu (optional, default: 4.0) mean of out-degree distribution + * @param sigma (optional, default: 1.3) standard deviation of out-degree distribution + * @param seed (optional, default: -1) seed for RNGs, -1 causes a random seed to be chosen + * @return Graph object */ - def logNormalGraph(sc: SparkContext, numVertices: Int, numEParts: Int, - mu: Double = 4.0, sigma: Double = 1.3): Graph[Long, Int] = { - val vertices = sc.parallelize(0 until numVertices, numEParts).map { src => - // Initialize the random number generator with the source vertex id - val rand = new Random(src) - val degree = math.min(numVertices.toLong, math.exp(rand.nextGaussian() * sigma + mu).toLong) - (src.toLong, degree) + def logNormalGraph( + sc: SparkContext, numVertices: Int, numEParts: Int = 0, mu: Double = 4.0, + sigma: Double = 1.3, seed: Long = -1): Graph[Long, Int] = { + + val evalNumEParts = if (numEParts == 0) sc.defaultParallelism else numEParts + + // Enable deterministic seeding + val seedRand = if (seed == -1) new Random() else new Random(seed) + val seed1 = seedRand.nextInt() + val seed2 = seedRand.nextInt() + + val vertices: RDD[(VertexId, Long)] = sc.parallelize(0 until numVertices, evalNumEParts).map { + src => (src, sampleLogNormal(mu, sigma, numVertices, seed = (seed1 ^ src))) } + val edges = vertices.flatMap { case (src, degree) => - new Iterator[Edge[Int]] { - // Initialize the random number generator with the source vertex id - val rand = new Random(src) - var i = 0 - override def hasNext(): Boolean = { i < degree } - override def next(): Edge[Int] = { - val nextEdge = Edge[Int](src, rand.nextInt(numVertices), i) - i += 1 - nextEdge - } - } + generateRandomEdges(src.toInt, degree.toInt, numVertices, seed = (seed2 ^ src)) } + Graph(vertices, edges, 0) } @@ -82,9 +85,10 @@ object GraphGenerators { // the edge data is the weight (default 1) val RMATc = 0.15 - def generateRandomEdges(src: Int, numEdges: Int, maxVertexId: Int): Array[Edge[Int]] = { - val rand = new Random() - Array.fill(maxVertexId) { Edge[Int](src, rand.nextInt(maxVertexId), 1) } + def generateRandomEdges( + src: Int, numEdges: Int, maxVertexId: Int, seed: Long = -1): Array[Edge[Int]] = { + val rand = if (seed == -1) new Random() else new Random(seed) + Array.fill(numEdges) { Edge[Int](src, rand.nextInt(maxVertexId), 1) } } /** @@ -97,9 +101,12 @@ object GraphGenerators { * @param mu the mean of the normal distribution * @param sigma the standard deviation of the normal distribution * @param maxVal exclusive upper bound on the value of the sample + * @param seed optional seed */ - private def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int): Int = { - val rand = new Random() + private[spark] def sampleLogNormal( + mu: Double, sigma: Double, maxVal: Int, seed: Long = -1): Int = { + val rand = if (seed == -1) new Random() else new Random(seed) + val sigmaSq = sigma * sigma val m = math.exp(mu + sigmaSq / 2.0) // expm1 is exp(m)-1 with better accuracy for tiny m diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala new file mode 100644 index 0000000000000..b346d4db2ef96 --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.util + +import org.scalatest.FunSuite + +import org.apache.spark.graphx.LocalSparkContext + +class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { + + test("GraphGenerators.generateRandomEdges") { + val src = 5 + val numEdges10 = 10 + val numEdges20 = 20 + val maxVertexId = 100 + + val edges10 = GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId) + assert(edges10.length == numEdges10) + + val correctSrc = edges10.forall(e => e.srcId == src) + assert(correctSrc) + + val correctWeight = edges10.forall(e => e.attr == 1) + assert(correctWeight) + + val correctRange = edges10.forall(e => e.dstId >= 0 && e.dstId <= maxVertexId) + assert(correctRange) + + val edges20 = GraphGenerators.generateRandomEdges(src, numEdges20, maxVertexId) + assert(edges20.length == numEdges20) + + val edges10_round1 = + GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId, seed = 12345) + val edges10_round2 = + GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId, seed = 12345) + assert(edges10_round1.zip(edges10_round2).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + + val edges10_round3 = + GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId, seed = 3467) + assert(!edges10_round1.zip(edges10_round3).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + } + + test("GraphGenerators.sampleLogNormal") { + val mu = 4.0 + val sigma = 1.3 + val maxVal = 100 + + val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal) + assert(dstId < maxVal) + + val dstId_round1 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) + val dstId_round2 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) + assert(dstId_round1 == dstId_round2) + + val dstId_round3 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 789) + assert(dstId_round1 != dstId_round3) + } + + test("GraphGenerators.logNormalGraph") { + withSpark { sc => + val mu = 4.0 + val sigma = 1.3 + val numVertices100 = 100 + + val graph = GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma) + assert(graph.vertices.count() == numVertices100) + + val graph_round1 = + GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma, seed = 12345) + val graph_round2 = + GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma, seed = 12345) + + val graph_round1_edges = graph_round1.edges.collect() + val graph_round2_edges = graph_round2.edges.collect() + + assert(graph_round1_edges.zip(graph_round2_edges).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + + val graph_round3 = + GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma, seed = 567) + + val graph_round3_edges = graph_round3.edges.collect() + + assert(!graph_round1_edges.zip(graph_round3_edges).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + } + } + +} diff --git a/make-distribution.sh b/make-distribution.sh index f7a6a9d838bb6..9b012b9222db4 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -28,7 +28,7 @@ set -o pipefail set -e # Figure out where the Spark framework is installed -FWDIR="$(cd `dirname $0`; pwd)" +FWDIR="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$FWDIR/dist" SPARK_TACHYON=false @@ -50,7 +50,8 @@ while (( "$#" )); do case $1 in --hadoop) echo "Error: '--hadoop' is no longer supported:" - echo "Error: use Maven options -Phadoop.version and -Pyarn.version" + echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." + echo "Error: Related profiles include hadoop-0.23, hdaoop-2.2, hadoop-2.3 and hadoop-2.4." exit_with_usage ;; --with-yarn) @@ -113,7 +114,17 @@ if ! which mvn &>/dev/null; then echo -e "Download Maven from https://maven.apache.org/" exit -1; fi + VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | tail -n 1) +SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles $@ 2>/dev/null\ + | grep -v "INFO"\ + | fgrep --count "hive";\ + # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ + # because we use "set -o pipefail" + echo -n) JAVA_CMD="$JAVA_HOME"/bin/java JAVA_VERSION=$("$JAVA_CMD" -version 2>&1) @@ -175,7 +186,7 @@ cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" mkdir -p "$DISTDIR/examples/src/main" cp -r "$FWDIR"/examples/src/main "$DISTDIR/examples/src/" -if [ "$SPARK_HIVE" == "true" ]; then +if [ "$SPARK_HIVE" == "1" ]; then cp "$FWDIR"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" fi @@ -209,10 +220,10 @@ if [ "$SPARK_TACHYON" == "true" ]; then wget "$TACHYON_URL" tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" - cp "tachyon-${TACHYON_VERSION}/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" + cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/src/main/java/tachyon/web/resources "$DISTDIR/tachyon/src/main/java/tachyon/web" + cp -r "tachyon-${TACHYON_VERSION}"/core/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" if [[ `uname -a` == Darwin* ]]; then # need to run sed differently on osx diff --git a/mllib/pom.xml b/mllib/pom.xml index c7a1e2ae75c84..a5eeef88e9d62 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 5823cb6e52e7f..12a3d91cd31a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -25,7 +25,7 @@ import org.apache.spark.mllib.linalg.Vector /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ -class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Serializable { +class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { /** Total number of clusters. */ def k: Int = clusterCenters.length diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 132b3af72d9ce..ac6eaea3f43ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -130,7 +130,7 @@ class IndexedRowMatrix( val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => IndexedRow(i, v) } - new IndexedRowMatrix(indexedRows, nRows, nCols) + new IndexedRowMatrix(indexedRows, nRows, B.numCols) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index fdd67160114ca..45dbf6044fcc5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -128,7 +128,7 @@ class LeastSquaresGradient extends Gradient { class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { val dotProduct = dot(data, weights) - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { @@ -146,7 +146,7 @@ class HingeGradient extends Gradient { weights: Vector, cumGradient: Vector): Double = { val dotProduct = dot(data, weights) - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5cdd258f6c20b..98596569b8c95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,8 +28,9 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impl._ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} +import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -65,36 +66,41 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) logDebug("algo = " + strategy.algo) + logDebug("maxBins = " + metadata.maxBins) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. timer.start("findSplitsBins") val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - val numBins = bins(0).length timer.stop("findSplitsBins") - logDebug("numBins = " + numBins) + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) .persist(StorageLevel.MEMORY_AND_DISK) - val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth - // the max number of nodes possible given the depth of the tree - val maxNumNodes = (2 << maxDepth) - 1 + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1 + val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1) // Initialize an array to hold parent impurity calculations for each node. - val parentImpurities = new Array[Double](maxNumNodes) + val parentImpurities = new Array[Double](maxNumNodesPlus1) // dummy value for top node (updated during first split calculation) - val nodes = new Array[Node](maxNumNodes) + val nodes = new Array[Node](maxNumNodesPlus1) // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) + // TODO: Calculate memory usage more precisely. + val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -124,26 +130,30 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. timer.start("findBestSplits") - val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, - metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) + val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] = + DecisionTree.findBestSplits(treeInput, parentImpurities, + metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = (1 << level) - 1 + val levelNodeIndexOffset = Node.startIndexInLevel(level) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { val nodeIndex = levelNodeIndexOffset + index - val isLeftChild = level != 0 && nodeIndex % 2 == 1 - val parentNodeIndex = if (isLeftChild) { // -1 for root node - (nodeIndex - 1) / 2 - } else { - (nodeIndex - 2) / 2 - } + // Extract info for this node (index) at the current level. timer.start("extractNodeInfo") - extractNodeInfo(nodeSplitStats, level, index, nodes) + val split = nodeSplitStats._1 + val stats = nodeSplitStats._2 + val predict = nodeSplitStats._3.predict + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) + val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + nodes(nodeIndex) = node timer.stop("extractNodeInfo") + if (level != 0) { // Set parent. - if (isLeftChild) { + val parentNodeIndex = Node.parentIndex(nodeIndex) + if (Node.isLeftChild(nodeIndex)) { nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) } else { nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) @@ -151,11 +161,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } // Extract info for nodes at the next lower level. timer.start("extractInfoForLowerLevels") - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) + if (level < maxDepth) { + val leftChildIndex = Node.leftChildIndex(nodeIndex) + val leftImpurity = stats.leftImpurity + logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity) + parentImpurities(leftChildIndex) = leftImpurity + + val rightChildIndex = Node.rightChildIndex(nodeIndex) + val rightImpurity = stats.rightImpurity + logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity) + parentImpurities(rightChildIndex) = rightImpurity + } timer.stop("extractInfoForLowerLevels") - logDebug("final best split = " + nodeSplitStats._1) + logDebug("final best split = " + split) } - require((1 << level) == splitsStatsForLevel.length) + require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -171,7 +191,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Initialize the top or root node of the tree. - val topNode = nodes(0) + val topNode = nodes(1) // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) @@ -183,47 +203,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo new DecisionTreeModel(topNode, strategy.algo) } - /** - * Extract the decision tree node information for the given tree level and node index - */ - private def extractNodeInfo( - nodeSplitStats: (Split, InformationGainStats), - level: Int, - index: Int, - nodes: Array[Node]): Unit = { - val split = nodeSplitStats._1 - val stats = nodeSplitStats._2 - val nodeIndex = (1 << level) - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - nodes(nodeIndex) = node - } - - /** - * Extract the decision tree node information for the children of the node - */ - private def extractInfoForLowerLevels( - level: Int, - index: Int, - maxDepth: Int, - nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double]): Unit = { - - if (level >= maxDepth) { - return - } - - val leftNodeIndex = (2 << level) - 1 + 2 * index - val leftImpurity = nodeSplitStats._2.leftImpurity - logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) - parentImpurities(leftNodeIndex) = leftImpurity - - val rightNodeIndex = leftNodeIndex + 1 - val rightImpurity = nodeSplitStats._2.rightImpurity - logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) - parentImpurities(rightNodeIndex) = rightImpurity - } } object DecisionTree extends Serializable with Logging { @@ -352,9 +331,9 @@ object DecisionTree extends Serializable with Logging { * Supported values: "gini" (recommended) or "entropy". * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) + * (suggested value: 5) * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) + * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ def trainClassifier( @@ -396,9 +375,9 @@ object DecisionTree extends Serializable with Logging { * Supported values: "variance". * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) + * (suggested value: 5) * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) + * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ def trainRegressor( @@ -425,9 +404,6 @@ object DecisionTree extends Serializable with Logging { impurity, maxDepth, maxBins) } - - private val InvalidBinIndex = -1 - /** * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. @@ -436,12 +412,12 @@ object DecisionTree extends Serializable with Logging { * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param splits possible splits for all features - * @param bins possible bins for all features + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array (over nodes) of splits with best split for each node at a given level. */ - protected[tree] def findBestSplits( + private[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], metadata: DecisionTreeMetadata, @@ -450,7 +426,7 @@ object DecisionTree extends Serializable with Logging { splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, @@ -459,7 +435,7 @@ object DecisionTree extends Serializable with Logging { // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) - var bestSplits = new Array[(Split, InformationGainStats)](0) + var bestSplits = new Array[(Split, InformationGainStats, Predict)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { @@ -474,6 +450,138 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a node + * at the current level being trained; that node's index is returned. + * + * @param node Node in tree from which to classify the given data point. + * @param binnedFeatures Binned feature vector for data point. + * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param unorderedFeatures Set of indices of unordered features. + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * set of nodes in a (level, group). + */ + private def predictNodeIndex( + node: Node, + binnedFeatures: Array[Int], + bins: Array[Array[Bin]], + unorderedFeatures: Set[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. + featureValueUpperBound <= node.split.get.threshold + } + case Categorical => { + val featureValue = binnedFeatures(featureIndex) + node.split.get.categories.contains(featureValue) + } + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") + } + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + Node.leftChildIndex(node.id) + } else { + Node.rightChildIndex(node.id) + } + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures) + } else { + predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures) + } + } + } + } + + /** + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. + * + * For ordered features, a single bin is updated. + * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param unorderedFeatures Set of indices of unordered features. + */ + private def mixedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + nodeIndex: Int, + bins: Array[Array[Bin]], + unorderedFeatures: Set[Int]): Unit = { + // Iterate over all features. + val numFeatures = treePoint.binnedFeatures.size + val nodeOffset = agg.getNodeOffset(nodeIndex) + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (unorderedFeatures.contains(featureIndex)) { + // Unordered feature + val featureValue = treePoint.binnedFeatures(featureIndex) + val (leftNodeFeatureOffset, rightNodeFeatureOffset) = + agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + // Update the left or right bin for each split. + val numSplits = agg.numSplits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { + agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label) + } else { + agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label) + } + splitIndex += 1 + } + } else { + // Ordered feature + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label) + } + featureIndex += 1 + } + } + + /** + * Helper for binSeqOp, for regression and for classification with only ordered features. + * + * For each feature, the sufficient statistics of one bin are updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @return agg + */ + private def orderedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + val nodeOffset = agg.getNodeOffset(nodeIndex) + // Iterate over all features. + val numFeatures = agg.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label) + featureIndex += 1 + } + } + /** * Returns an array of optimal splits for a group of nodes at a given level * @@ -481,8 +589,9 @@ object DecisionTree extends Serializable with Logging { * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param splits possible splits for all features - * @param bins possible bins for all features, indexed as (numFeatures)(numBins) + * @param nodes Array of all nodes in the tree. Used for matching data points to nodes. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. @@ -497,7 +606,7 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]], timer: TimeTracker, numGroups: Int = 1, - groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { + groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -527,88 +636,22 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = (1 << level) / numGroups + val numNodes = Node.maxNodesInLevel(level) / numGroups logDebug("numNodes = " + numNodes) - // Find the number of features by looking at the first sample. - val numFeatures = metadata.numFeatures - logDebug("numFeatures = " + numFeatures) - - // numBins: Number of bins = 1 + number of possible splits - val numBins = bins(0).length - logDebug("numBins = " + numBins) - - val numClasses = metadata.numClasses - logDebug("numClasses = " + numClasses) - - val isMulticlass = metadata.isMulticlass - logDebug("isMulticlass = " + isMulticlass) - - val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) + logDebug("numFeatures = " + metadata.numFeatures) + logDebug("numClasses = " + metadata.numClasses) + logDebug("isMulticlass = " + metadata.isMulticlass) + logDebug("isMulticlassWithCategoricalFeatures = " + + metadata.isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** - * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a node - * at the current level being trained; that node's index is returned. - * - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - */ - def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { - if (node.isLeaf) { - node.id - } else { - val featureIndex = node.split.get.feature - val splitLeft = node.split.get.featureType match { - case Continuous => { - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] - // We do not need to check lowSplit since bins are separated by splits. - featureValueUpperBound <= node.split.get.threshold - } - case Categorical => { - val featureValue = if (metadata.isUnordered(featureIndex)) { - binnedFeatures(featureIndex) - } else { - val binIndex = binnedFeatures(featureIndex) - bins(featureIndex)(binIndex).category - } - node.split.get.categories.contains(featureValue) - } - case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") - } - if (node.leftNode.isEmpty || node.rightNode.isEmpty) { - // Return index from next layer of nodes to train - if (splitLeft) { - node.id * 2 + 1 // left - } else { - node.id * 2 + 2 // right - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftNode.get, binnedFeatures) - } else { - predictNodeIndex(node.rightNode.get, binnedFeatures) - } - } - } - } - - def nodeIndexToLevel(idx: Int): Int = { - if (idx == 0) { - 0 - } else { - math.floor(math.log(idx) / math.log(2)).toInt - } - } - - // Used for treePointToNodeIndex - val levelOffset = (1 << level) - 1 + // Used for treePointToNodeIndex to get an index for this (level, group). + // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level. + // - groupShift corrects for groups in this level before the current group. + val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift /** * Find the node index for the given example. @@ -619,661 +662,287 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { 0 } else { - val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures) - // Get index for this (level, group). - globalNodeIndex - levelOffset - groupShift - } - } - - /** - * Increment aggregate in location for (node, feature, bin, label). - * - * @param treePoint Data point being aggregated. - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def updateBinForOrderedFeature( - treePoint: TreePoint, - agg: Array[Double], - nodeIndex: Int, - featureIndex: Int): Unit = { - // Update the left or right count for one bin. - val aggIndex = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - numClasses * treePoint.binnedFeatures(featureIndex) + - treePoint.label.toInt - agg(aggIndex) += 1 - } - - /** - * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), - * where [bins] ranges over all bins. - * Updates left or right side of aggregate depending on split. - * - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @param treePoint Data point being aggregated. - * @param agg Indexed by (left/right, node, feature, bin, label) - * where label is the least significant bit. - * The left/right specifier is a 0/1 index indicating left/right child info. - * @param rightChildShift Offset for right side of agg. - */ - def updateBinForUnorderedFeature( - nodeIndex: Int, - featureIndex: Int, - treePoint: TreePoint, - agg: Array[Double], - rightChildShift: Int): Unit = { - val featureValue = treePoint.binnedFeatures(featureIndex) - // Update the left or right count for one bin. - val aggShift = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - treePoint.label.toInt - // Find all matching bins and increment their values - val featureCategories = metadata.featureArity(featureIndex) - val numCategoricalBins = (1 << featureCategories - 1) - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - val aggIndex = aggShift + binIndex * numClasses - if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { - agg(aggIndex) += 1 - } else { - agg(rightChildShift + aggIndex) += 1 - } - binIndex += 1 - } - } - - /** - * Helper for binSeqOp. - * - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def binaryOrNotCategoricalBinSeqOp( - agg: Array[Double], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) - featureIndex += 1 - } - } - - val rightChildShift = numClasses * numBins * numFeatures * numNodes - - /** - * Helper for binSeqOp. - * - * @param agg Array storing aggregate calculation. - * For ordered features, this is of size: - * numClasses * numBins * numFeatures * numNodes. - * For unordered features, this is of size: - * 2 * numClasses * numBins * numFeatures * numNodes. - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def multiclassWithCategoricalBinSeqOp( - agg: Array[Double], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - val label = treePoint.label - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isUnordered(featureIndex)) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) - } else { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) - } - featureIndex += 1 - } - } - - /** - * Performs a sequential aggregation over a partition for regression. - * For l nodes, k features, - * the count, sum, sum of squares of one of the p bins is incremented. - * - * @param agg Array storing aggregate calculation, updated by this function. - * Size: 3 * numBins * numFeatures * numNodes - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @return agg - */ - def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { - val label = treePoint.label - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Update count, sum, and sum^2 for one bin. - val binIndex = treePoint.binnedFeatures(featureIndex) - val aggIndex = - 3 * numBins * numFeatures * nodeIndex + - 3 * numBins * featureIndex + - 3 * binIndex - agg(aggIndex) += 1 - agg(aggIndex + 1) += label - agg(aggIndex + 2) += label * label - featureIndex += 1 + val globalNodeIndex = + predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) + globalNodeIndex - globalNodeIndexOffset } } /** * Performs a sequential aggregation over a partition. - * For l nodes, k features, - * For classification: - * Either the left count or the right count of one of the bins is - * incremented based upon whether the feature is classified as 0 or 1. - * For regression: - * The count, sum, sum of squares of one of the bins is incremented. * - * @param agg Array storing aggregate calculation, updated by this function. - * Size for classification: - * numClasses * numBins * numFeatures * numNodes for ordered features, or - * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. - * Size for regression: - * 3 * numBins * numFeatures * numNodes. + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). * @param treePoint Data point being aggregated. * @return agg */ - def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + def binSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint): DTStatsAggregator = { val nodeIndex = treePointToNodeIndex(treePoint) // If the example does not reach this level, then nodeIndex < 0. // If the example reaches this level but is handled in a different group, // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). if (nodeIndex >= 0 && nodeIndex < numNodes) { - if (metadata.isClassification) { - if (isMulticlassWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) - } else { - binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) - } + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg, treePoint, nodeIndex) } else { - regressionBinSeqOp(agg, treePoint, nodeIndex) + mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) } } agg } - // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins) - logDebug("binAggregateLength = " + binAggregateLength) - - /** - * Combines the aggregates from partitions. - * @param agg1 Array containing aggregates from one or more partitions - * @param agg2 Array containing aggregates from one or more partitions - * @return Combined aggregate from agg1 and agg2 - */ - def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { - var index = 0 - val combinedAggregate = new Array[Double](binAggregateLength) - while (index < binAggregateLength) { - combinedAggregate(index) = agg1(index) + agg2(index) - index += 1 - } - combinedAggregate - } - // Calculate bin aggregates. timer.start("aggregation") - val binAggregates = { - input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) + val binAggregates: DTStatsAggregator = { + val initAgg = new DTStatsAggregator(metadata, numNodes) + input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) } timer.stop("aggregation") - logDebug("binAggregates.length = " + binAggregates.length) - - /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftNodeAgg left node aggregates for this (feature, split) - * @param rightNodeAgg right node aggregate for this (feature, split) - * @param topImpurity impurity of the parent node - * @return information gain and statistics for all splits - */ - def calculateGainForSplit( - leftNodeAgg: Array[Double], - rightNodeAgg: Array[Double], - topImpurity: Double): InformationGainStats = { - if (metadata.isClassification) { - val leftTotalCount = leftNodeAgg.sum - val rightTotalCount = rightNodeAgg.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) - classIndex += 1 - } - metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) - } - } - - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) - } - - // Sum of count for each label - val leftrightNodeAgg: Array[Double] = - leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => - leftCount + rightCount - } - - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if (currentValue > maxValue) { - (currentIndex, currentValue, currentIndex + 1) - } else { - (maxIndex, maxValue, currentIndex + 1) - } - } - if (result._1 < 0) { - throw new RuntimeException("DecisionTree internal error:" + - " calculateGainForSplit failed in indexOfLargestArrayElement") - } - result._1 - } - - val predict = indexOfLargestArrayElement(leftrightNodeAgg) - val prob = leftrightNodeAgg(predict) / totalCount - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - metadata.impurity.calculate(leftNodeAgg, leftTotalCount) - } - val rightImpurity = if (rightTotalCount == 0) { - topImpurity - } else { - metadata.impurity.calculate(rightNodeAgg, rightTotalCount) - } - - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - - } else { - // Regression + // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") + val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes) + // Iterating over all nodes at this level + var nodeIndex = 0 + while (nodeIndex < numNodes) { + val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) + logDebug("node impurity = " + nodeImpurity) + bestSplits(nodeIndex) = + binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits) + logDebug("best split = " + bestSplits(nodeIndex)._1) + nodeIndex += 1 + } + timer.stop("chooseSplits") - val leftCount = leftNodeAgg(0) - val leftSum = leftNodeAgg(1) - val leftSumSquares = leftNodeAgg(2) + bestSplits + } - val rightCount = rightNodeAgg(0) - val rightSum = rightNodeAgg(1) - val rightSumSquares = rightNodeAgg(2) + /** + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + private def calculateGainForSplit( + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + topImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata): InformationGainStats = { + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count + + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return InformationGainStats.invalidInformationGainStats + } - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - metadata.impurity.calculate(count, sum, sumSquares) - } - } + val totalCount = leftCount + rightCount - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, - Double.MinValue, leftSum / leftCount) - } + // impurity of parent node + val impurity = if (level > 0) { + topImpurity + } else { + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + parentNodeAgg.calculate() + } - val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares) + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - } + // if information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. + if (gain < metadata.minInfoGain) { + return InformationGainStats.invalidInformationGainStats } - /** - * Extracts left and right split aggregates. - * @param binData Aggregate array slice from getBinDataForNode. - * For classification: - * For unordered features, this is leftChildData ++ rightChildData, - * each of which is indexed by (feature, split/bin, class), - * with class being the least significant bit. - * For ordered features, this is of size numClasses * numBins * numFeatures. - * For regression: - * This is of size 2 * numFeatures * numBins. - * @return (leftNodeAgg, rightNodeAgg) pair of arrays. - * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). - * For regression, each array is of size (numFeatures, (numBins - 1), 3). - * - */ - def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { - - - /** - * The input binData is indexed as (feature, bin, class). - * This computes cumulative sums over splits. - * Each (feature, class) pair is handled separately. - * Note: numSplits = numBins - 1. - * @param leftNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 0, ..., numSplits - 2) is set to be - * the cumulative sum (from left) over binData for bins 0, ..., i. - * @param rightNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 1, ..., numSplits - 1) is set to be - * the cumulative sum (from right) over binData for bins - * numBins - 1, ..., numBins - 1 - i. - */ - def findAggForOrderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins - - var classIndex = 0 - while (classIndex < numClasses) { - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(classIndex) - = binData(shift + (numClasses * (numBins - 1)) + classIndex) - classIndex += 1 - } + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + } - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - var innerClassIndex = 0 - while (innerClassIndex < numClasses) { - leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) - = binData(shift + numClasses * splitIndex + innerClassIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = - binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) - innerClassIndex += 1 - } - splitIndex += 1 - } - } + /** + * Calculate predict value for current node, given stats of any split. + * Note that this function is called only once for each node. + * @param leftImpurityCalculator left node aggregates for a split + * @param rightImpurityCalculator right node aggregates for a node + * @return predict value for current node + */ + private def calculatePredict( + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator): Predict = { + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + val predict = parentNodeAgg.predict + val prob = parentNodeAgg.prob(predict) + + new Predict(predict, prob) + } - /** - * Reshape binData for this feature. - * Indexes binData as (feature, split, class) with class as the least significant bit. - * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value - */ - def findAggForUnorderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - val rightChildShift = numClasses * numBins * numFeatures - var splitIndex = 0 - while (splitIndex < numBins - 1) { - var classIndex = 0 - while (classIndex < numClasses) { - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins + splitIndex * numClasses - val leftBinValue = binData(shift + classIndex) - val rightBinValue = binData(rightChildShift + shift + classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue - rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue - classIndex += 1 - } - splitIndex += 1 - } - } + /** + * Find the best split for a node. + * @param binAggregates Bin statistics. + * @param nodeIndex Index for node to split in this (level, group). + * @param nodeImpurity Impurity of the node (nodeIndex). + * @return tuple for best split: (Split, information gain) + */ + private def binsToBestSplit( + binAggregates: DTStatsAggregator, + nodeIndex: Int, + nodeImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { - def findAggForRegression( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(numBins - 2)(2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - var i = 0 // index for regression histograms - while (i < 3) { // count, sum, sum^2 - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + - leftNodeAgg(featureIndex)(splitIndex - 1)(i) - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = - binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) - i += 1 - } - splitIndex += 1 - } - } + logDebug("node impurity = " + nodeImpurity) - if (metadata.isClassification) { - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isUnordered(featureIndex)) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } else { - // Regression - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } - } + // calculate predict only once + var predict: Option[Predict] = None - /** - * Calculates information gain for all nodes splits. - */ - def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - nodeImpurity: Double): Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - - var featureIndex = 0 - while (featureIndex < numFeatures) { - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + // For each (feature, split), calculate the gain, and select the best (feature, split). + val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex => + val numSplits = metadata.numSplits(featureIndex) + if (metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) var splitIndex = 0 - while (splitIndex < numSplitsForFeature) { - gains(featureIndex)(splitIndex) = - calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), - rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) + while (splitIndex < numSplits) { + binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) splitIndex += 1 } - featureIndex += 1 - } - gains - } - - /** - * Get the number of splits for a feature. - */ - def getNumSplitsForFeature(featureIndex: Int): Int = { - if (metadata.isContinuous(featureIndex)) { - numBins - 1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIdx, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = + binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { - // Categorical feature - val featureCategories = metadata.featureArity(featureIndex) - if (metadata.isUnordered(featureIndex)) { - (1 << featureCategories - 1) - 1 - } else { - featureCategories - } - } - } - - /** - * Find the best split for a node. - * @param binData Bin data slice for this node, given by getBinDataForNode. - * @param nodeImpurity impurity of the top node - * @return tuple of split and information gain - */ - def binsToBestSplit( - binData: Array[Double], - nodeImpurity: Double): (Split, InformationGainStats) = { - - logDebug("node impurity = " + nodeImpurity) - - // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) - - // Calculate gains for all splits. - val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - - val (bestFeatureIndex, bestSplitIndex, gainStats) = { - // Initialize with infeasible values. - var bestFeatureIndex = Int.MinValue - var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) - // Iterate over features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Iterate over all splits. - var splitIndex = 0 - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) - while (splitIndex < numSplitsForFeature) { - val gainStats = gains(featureIndex)(splitIndex) - if (gainStats.gain > bestGainStats.gain) { - bestGainStats = gainStats - bestFeatureIndex = featureIndex - bestSplitIndex = splitIndex + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val numBins = metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = if (metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.calculate() + } else { + Double.MaxValue } - splitIndex += 1 + (featureValue, centroid) + } + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // the bins are ordered by the centroid of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.predict + } else { + Double.MaxValue + } + (featureValue, centroid) } - featureIndex += 1 } - (bestFeatureIndex, bestSplitIndex, bestGainStats) - } - logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - (splits(bestFeatureIndex)(bestSplitIndex), gainStats) - } + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - /** - * Get bin data for one node. - */ - def getBinDataForNode(node: Int): Array[Double] = { - if (metadata.isClassification) { - if (isMulticlassWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 } - } else { - // Regression - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) - binsForNode + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + (bestFeatureSplit, bestFeatureGainStats) } - } + }.maxBy(_._2.gain) - // Calculate best splits for all nodes at a given level - timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - // Iterating over all nodes at this level - var node = 0 - while (node < numNodes) { - val nodeImpurityIndex = (1 << level) - 1 + node + groupShift - val binsForNode: Array[Double] = getBinDataForNode(node) - logDebug("nodeImpurityIndex = " + nodeImpurityIndex) - val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("parent node impurity = " + parentNodeImpurity) - bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) - node += 1 - } - timer.stop("chooseSplits") + require(predict.isDefined, "must calculate predict for each node") - bestSplits + (bestSplit, bestSplitStats, predict.get) } /** * Get the number of values to be stored per node in the bin aggregates. - * - * @param numBins Number of bins = 1 + number of possible splits. */ - private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = { + private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = { + val totalBins = metadata.numBins.sum if (metadata.isClassification) { - if (metadata.isMulticlassWithCategoricalFeatures) { - 2 * metadata.numClasses * numBins * metadata.numFeatures - } else { - metadata.numClasses * numBins * metadata.numFeatures - } + metadata.numClasses * totalBins } else { - 3 * numBins * metadata.numFeatures + 3 * totalBins } } @@ -1284,6 +953,7 @@ object DecisionTree extends Serializable with Logging { * Continuous features: * For each feature, there are numBins - 1 possible splits representing the possible binary * decisions at each node in the tree. + * This finds locations (feature values) for splits using a subsample of the data. * * Categorical features: * For each feature, there is 1 bin per split. @@ -1292,7 +962,6 @@ object DecisionTree extends Serializable with Logging { * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are (1 << maxFeatureValue - 1) - 1 splits. * (b) "ordered features" * For regression and binary classification, * and for multiclass classification with a high-arity feature, @@ -1302,7 +971,7 @@ object DecisionTree extends Serializable with Logging { * @param metadata Learning and dataset metadata * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] - * of size (numFeatures, numBins - 1). + * of size (numFeatures, numSplits). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] * of size (numFeatures, numBins). */ @@ -1310,84 +979,80 @@ object DecisionTree extends Serializable with Logging { input: RDD[LabeledPoint], metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { - val count = input.count() + logDebug("isMulticlass = " + metadata.isMulticlass) - // Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.size - - val maxBins = metadata.maxBins - val numBins = if (maxBins <= count) maxBins else count.toInt - logDebug("numBins = " + numBins) - val isMulticlass = metadata.isMulticlass - logDebug("isMulticlass = " + isMulticlass) - - /* - * Ensure numBins is always greater than the categories. For multiclass classification, - * numBins should be greater than 2^(maxCategories - 1) - 1. - * It's a limitation of the current implementation but a reasonable trade-off since features - * with large number of categories get favored over continuous features. - * - * This needs to be checked here instead of in Strategy since numBins can be determined - * by the number of training examples. - * TODO: Allow this case, where we simply will know nothing about some categories. - */ - if (metadata.featureArity.size > 0) { - val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2 - require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + - "in categorical features") - } - - // Calculate the number of sample for approximate quantile calculation. - val requiredSamples = numBins*numBins - val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 - logDebug("fraction of data used for calculating quantiles = " + fraction) + val numFeatures = metadata.numFeatures - // sampled input for RDD calculation - val sampledInput = + // Sample the input only if there are continuous features. + val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) + val sampledInput = if (hasContinuousFeatures) { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + val fraction = if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() - val numSamples = sampledInput.length - - val stride: Double = numSamples.toDouble / numBins - logDebug("stride = " + stride) + } else { + new Array[LabeledPoint](0) + } metadata.quantileStrategy match { case Sort => - val splits = Array.ofDim[Split](numFeatures, numBins - 1) - val bins = Array.ofDim[Bin](numFeatures, numBins) + val splits = new Array[Array[Split]](numFeatures) + val bins = new Array[Array[Bin]](numFeatures) // Find all splits. - // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Check whether the feature is continuous. - val isFeatureContinuous = metadata.isContinuous(featureIndex) - if (isFeatureContinuous) { + val numSplits = metadata.numSplits(featureIndex) + val numBins = metadata.numBins(featureIndex) + if (metadata.isContinuous(featureIndex)) { + val numSamples = sampledInput.length + splits(featureIndex) = new Array[Split](numSplits) + bins(featureIndex) = new Array[Bin](numBins) val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble / numBins + val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) logDebug("stride = " + stride) - for (index <- 0 until numBins - 1) { - val sampleIndex = index * stride.toInt + for (splitIndex <- 0 until numSplits) { + val sampleIndex = splitIndex * stride.toInt // Set threshold halfway in between 2 samples. val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 - val split = new Split(featureIndex, threshold, Continuous, List()) - splits(featureIndex)(index) = split + splits(featureIndex)(splitIndex) = + new Split(featureIndex, threshold, Continuous, List()) } - } else { // Categorical feature - val featureCategories = metadata.featureArity(featureIndex) - - // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint. + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) + for (splitIndex <- 1 until numSplits) { + bins(featureIndex)(splitIndex) = + new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), + Continuous, Double.MinValue) + } + bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) + } else { + // Categorical feature + val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { - // 2^(maxFeatureValue- 1) - 1 combinations - var index = 0 - while (index < (1 << featureCategories - 1) - 1) { - val categories: List[Double] - = extractMultiClassCategories(index + 1, featureCategories) - splits(featureIndex)(index) - = new Split(featureIndex, Double.MinValue, Categorical, categories) - bins(featureIndex)(index) = { - if (index == 0) { + // TODO: The second half of the bins are unused. Actually, we could just use + // splits and not build bins for unordered features. That should be part of + // a later PR since it will require changing other code (using splits instead + // of bins in a few places). + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + splits(featureIndex) = new Array[Split](numSplits) + bins(featureIndex) = new Array[Bin](numBins) + var splitIndex = 0 + while (splitIndex < numSplits) { + val categories: List[Double] = + extractMultiClassCategories(splitIndex + 1, featureArity) + splits(featureIndex)(splitIndex) = + new Split(featureIndex, Double.MinValue, Categorical, categories) + bins(featureIndex)(splitIndex) = { + if (splitIndex == 0) { new Bin( new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), @@ -1395,96 +1060,24 @@ object DecisionTree extends Serializable with Logging { Double.MinValue) } else { new Bin( - splits(featureIndex)(index - 1), - splits(featureIndex)(index), + splits(featureIndex)(splitIndex - 1), + splits(featureIndex)(splitIndex), Categorical, Double.MinValue) } } - index += 1 - } - } else { // ordered feature - /* For a given categorical feature, use a subsample of the data - * to choose how to arrange possible splits. - * This examines each category and computes a centroid. - * These centroids are later used to sort the possible splits. - * centroidForCategories is a mapping: category (for the given feature) --> centroid - */ - val centroidForCategories = { - if (isMulticlass) { - // For categorical variables in multiclass classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the impurity of their corresponding labels. - sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) - .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. - sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) - } - } - - logDebug("centroid for categories = " + centroidForCategories.mkString(",")) - - // Check for missing categorical variables and putting them last in the sorted list. - val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until featureCategories) { - if (centroidForCategories.contains(i)) { - fullCentroidForCategories(i) = centroidForCategories(i) - } else { - fullCentroidForCategories(i) = Double.MaxValue - } - } - - // bins sorted by centroids - val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - - logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) - - var categoriesForSplit = List[Double]() - categoriesSortedByCentroid.iterator.zipWithIndex.foreach { - case ((key, value), index) => - categoriesForSplit = key :: categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, - Categorical, categoriesForSplit) - bins(featureIndex)(index) = { - if (index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, key) - } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Categorical, key) - } - } + splitIndex += 1 } + } else { + // Ordered features + // Bins correspond to feature values, so we do not need to compute splits or bins + // beforehand. Splits are constructed as needed during training. + splits(featureIndex) = new Array[Split](0) + bins(featureIndex) = new Array[Bin](0) } } featureIndex += 1 } - - // Find all bins. - featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = metadata.isContinuous(featureIndex) - if (isFeatureContinuous) { // Bins for categorical variables are already assigned. - bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), - splits(featureIndex)(0), Continuous, Double.MinValue) - for (index <- 1 until numBins - 1) { - val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Continuous, Double.MinValue) - bins(featureIndex)(index) = bin - } - bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), - new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) - } - featureIndex += 1 - } (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index cfc8192a85abd..987fe632c91ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -49,8 +49,15 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. + * @param minInstancesPerNode Minimum number of instances each child must have after split. + * Default value is 1. If a split cause left or right child + * to have less than minInstancesPerNode, + * this split will not be considered as a valid split. + * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. + * If a split has less information gain than minInfoGain, + * this split will not be considered as a valid split. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is - * 128 MB. + * 256 MB. */ @Experimental class Strategy ( @@ -58,10 +65,12 @@ class Strategy ( val impurity: Impurity, val maxDepth: Int, val numClassesForClassification: Int = 2, - val maxBins: Int = 100, + val maxBins: Int = 32, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable { + val minInstancesPerNode: Int = 1, + val minInfoGain: Double = 0.0, + val maxMemoryInMB: Int = 256) extends Serializable { if (algo == Classification) { require(numClassesForClassification >= 2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala new file mode 100644 index 0000000000000..866d85a79bea1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.tree.impurity._ + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + */ +private[tree] class DTStatsAggregator( + val metadata: DecisionTreeMetadata, + val numNodes: Int) extends Serializable { + + /** + * [[ImpurityAggregator]] instance specifying the impurity type. + */ + val impurityAggregator: ImpurityAggregator = metadata.impurity match { + case Gini => new GiniAggregator(metadata.numClasses) + case Entropy => new EntropyAggregator(metadata.numClasses) + case Variance => new VarianceAggregator() + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + + /** + * Number of elements (Double values) used for the sufficient statistics of each bin. + */ + val statsSize: Int = impurityAggregator.statsSize + + val numFeatures: Int = metadata.numFeatures + + /** + * Number of bins for each feature. This is indexed by the feature index. + */ + val numBins: Array[Int] = metadata.numBins + + /** + * Number of splits for the given feature. + */ + def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex) + + /** + * Indicator for each feature of whether that feature is an unordered feature. + * TODO: Is Array[Boolean] any faster? + */ + def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ + private val featureOffsets: Array[Int] = { + def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { + if (isUnordered(featureIndex)) { + total + 2 * numBins(featureIndex) + } else { + total + numBins(featureIndex) + } + } + Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray + } + + /** + * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. + */ + private val nodeStride: Int = featureOffsets.last + + /** + * Total number of elements stored in this aggregator. + */ + val allStatsSize: Int = numNodes * nodeStride + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex)) + * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)) + */ + val allStats: Array[Double] = new Array[Double](allStatsSize) + + /** + * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + */ + def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize) + } + + /** + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + */ + def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { + val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label) + } + + /** + * Pre-compute node offset for use with [[nodeUpdate]]. + */ + def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + + /** + * Faster version of [[update]]. + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. + */ + def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { + val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label) + } + + /** + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For ordered features only. + */ + def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + require(!isUnordered(featureIndex), + s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" + + s" for unordered feature $featureIndex.") + nodeIndex * nodeStride + featureOffsets(featureIndex) + } + + /** + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For unordered features only. + */ + def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { + require(isUnordered(featureIndex), + s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + + s" but was called for ordered feature $featureIndex.") + val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) + (baseOffset, baseOffset + numBins(featureIndex) * statsSize) + } + + /** + * Faster version of [[update]]. + * Update the stats for a given (node, feature, bin), using the given label. + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + */ + def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = { + impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label) + } + + /** + * For a given (node, feature), merge the stats for two bins. + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + * @param binIndex The other bin is merged into this bin. + * @param otherBinIndex This bin is not modified. + */ + def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize, + nodeFeatureOffset + otherBinIndex * statsSize) + } + + /** + * Merge this aggregator with another, and returns this aggregator. + * This method modifies this aggregator in-place. + */ + def merge(other: DTStatsAggregator): DTStatsAggregator = { + require(allStatsSize == other.allStatsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + var i = 0 + // TODO: Test BLAS.axpy + while (i < allStatsSize) { + allStats(i) += other.allStats(i) + i += 1 + } + this + } + +} + +private[tree] object DTStatsAggregator extends Serializable { + + /** + * Combines two aggregates (modifying the first) and returns the combination. + */ + def binCombOp( + agg1: DTStatsAggregator, + agg2: DTStatsAggregator): DTStatsAggregator = { + agg1.merge(agg2) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index d9eda354dc986..5ceaa8154d11a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.rdd.RDD - /** * Learning and dataset metadata for DecisionTree. * * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. * For regression: fixed at 0 (no meaning). + * @param maxBins Maximum number of bins, for all features. * @param featureArity Map: categorical feature index --> arity. * I.e., the feature takes values in {0, ..., arity - 1}. + * @param numBins Number of bins for each feature. */ private[tree] class DecisionTreeMetadata( val numFeatures: Int, @@ -42,8 +43,11 @@ private[tree] class DecisionTreeMetadata( val maxBins: Int, val featureArity: Map[Int, Int], val unorderedFeatures: Set[Int], + val numBins: Array[Int], val impurity: Impurity, - val quantileStrategy: QuantileStrategy) extends Serializable { + val quantileStrategy: QuantileStrategy, + val minInstancesPerNode: Int, + val minInfoGain: Double) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -57,10 +61,26 @@ private[tree] class DecisionTreeMetadata( def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + /** + * Number of splits for the given feature. + * For unordered features, there are 2 bins per split. + * For ordered features, there is 1 more bin than split. + */ + def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { + numBins(featureIndex) >> 1 + } else { + numBins(featureIndex) - 1 + } + } private[tree] object DecisionTreeMetadata { + /** + * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. + * This computes which categorical features will be ordered vs. unordered, + * as well as the number of splits and bins for each feature. + */ def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { val numFeatures = input.take(1)(0).features.size @@ -70,32 +90,56 @@ private[tree] object DecisionTreeMetadata { case Regression => 0 } - val maxBins = math.min(strategy.maxBins, numExamples).toInt - val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0) + val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + + // We check the number of bins here against maxPossibleBins. + // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified + // based on the number of training examples. + if (strategy.categoricalFeaturesInfo.nonEmpty) { + val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + require(maxCategoriesPerFeature <= maxPossibleBins, + s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + + s"in categorical features (= $maxCategoriesPerFeature)") + } val unorderedFeatures = new mutable.HashSet[Int]() + val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) if (numClasses > 2) { - strategy.categoricalFeaturesInfo.foreach { case (f, k) => - if (k - 1 < log2MaxBinsp1) { - // Note: The above check is equivalent to checking: - // numUnorderedBins = (1 << k - 1) - 1 < maxBins - unorderedFeatures.add(f) + // Multiclass classification + val maxCategoriesForUnorderedFeature = + ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) } else { - // TODO: Allow this case, where we simply will know nothing about some categories? - require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + - s"in categorical features (>= $k)") + numBins(featureIndex) = numCategories } } } else { - strategy.categoricalFeaturesInfo.foreach { case (f, k) => - require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + - s"in categorical features (>= $k)") + // Binary classification or regression + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + numBins(featureIndex) = numCategories } } - new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins, - strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, - strategy.impurity, strategy.quantileCalculationStrategy) + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, + strategy.impurity, strategy.quantileCalculationStrategy, + strategy.minInstancesPerNode, strategy.minInfoGain) } + /** + * Given the arity of a categorical feature (arity = number of categories), + * return the number of bins for the feature if it is to be treated as an unordered feature. + * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; + * there are math.pow(2, arity - 1) - 1 such splits. + * Each split has 2 corresponding bins. + */ + def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1) + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 170e43e222083..35e361ae309cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -48,54 +48,63 @@ private[tree] object TreePoint { * binning feature values in preparation for DecisionTree training. * @param input Input dataset. * @param bins Bins for features, of size (numFeatures, numBins). - * @param metadata Learning and dataset metadata + * @param metadata Learning and dataset metadata * @return TreePoint dataset representation */ def convertToTreeRDD( input: RDD[LabeledPoint], bins: Array[Array[Bin]], metadata: DecisionTreeMetadata): RDD[TreePoint] = { + // Construct arrays for featureArity and isUnordered for efficiency in the inner loop. + val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) + val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures) + var featureIndex = 0 + while (featureIndex < metadata.numFeatures) { + featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) + isUnordered(featureIndex) = metadata.isUnordered(featureIndex) + featureIndex += 1 + } input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, metadata) + TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered) } } /** * Convert one LabeledPoint into its TreePoint representation. * @param bins Bins for features, of size (numFeatures, numBins). + * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories + * for categorical features. + * @param isUnordered Array index by feature, with value true for unordered categorical features. */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, bins: Array[Array[Bin]], - metadata: DecisionTreeMetadata): TreePoint = { - + featureArity: Array[Int], + isUnordered: Array[Boolean]): TreePoint = { val numFeatures = labeledPoint.features.size - val numBins = bins(0).size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), - metadata.isUnordered(featureIndex), bins, metadata.featureArity) + arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), + isUnordered(featureIndex), bins) featureIndex += 1 } - new TreePoint(labeledPoint.label, arr) } /** * Find bin for one (labeledPoint, feature). * + * @param featureArity 0 for continuous features; number of categories for categorical features. * @param isUnorderedFeature (only applies if feature is categorical) * @param bins Bins for features, of size (numFeatures, numBins). - * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity */ private def findBin( featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, + featureArity: Int, isUnorderedFeature: Boolean, - bins: Array[Array[Bin]], - categoricalFeaturesInfo: Map[Int, Int]): Int = { + bins: Array[Array[Bin]]): Int = { /** * Binary search helper method for continuous feature. @@ -121,44 +130,7 @@ private[tree] object TreePoint { -1 } - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). - */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex - } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - -1 - } - - if (isFeatureContinuous) { + if (featureArity == 0) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() if (binIndex == -1) { @@ -168,18 +140,17 @@ private[tree] object TreePoint { } binIndex } else { - // Perform sequential search to find bin for categorical features. - val binIndex = if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - if (binIndex == -1) { - throw new RuntimeException("No bin was found for categorical feature." + - " This error can occur when given invalid data values (such as NaN)." + - s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + // Categorical feature bins are indexed by feature values. + val featureValue = labeledPoint.features(featureIndex) + if (featureValue < 0 || featureValue >= featureArity) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureArity - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) } - binIndex + featureValue.toInt } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 96d2471e1f88c..1c8afc2d0f4bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -74,3 +74,87 @@ object Entropy extends Impurity { def instance = this } + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + */ +private[tree] class EntropyAggregator(numClasses: Int) + extends ImpurityAggregator(numClasses) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + allStats(offset + label.toInt) += 1 + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { + new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) + } + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[EntropyAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: EntropyCalculator = new EntropyCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = Entropy.calculate(stats, stats.sum) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats.sum.toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + indexOfLargestArrayElement(stats) + } + + /** + * Probability of the label given by [[predict]]. + */ + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < stats.length, + s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + val cnt = count + if (cnt == 0) { + 0 + } else { + stats(lbl) / cnt + } + } + + override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])" + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index d586f449048bb..5cfdf345d163c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -70,3 +70,87 @@ object Gini extends Impurity { def instance = this } + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + */ +private[tree] class GiniAggregator(numClasses: Int) + extends ImpurityAggregator(numClasses) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + allStats(offset + label.toInt) += 1 + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { + new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) + } + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[GiniAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: GiniCalculator = new GiniCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = Gini.calculate(stats, stats.sum) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats.sum.toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + indexOfLargestArrayElement(stats) + } + + /** + * Probability of the label given by [[predict]]. + */ + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < stats.length, + s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + val cnt = count + if (cnt == 0) { + 0 + } else { + stats(lbl) / cnt + } + } + + override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])" + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 92b0c7b4a6fbc..5a047d6cb5480 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -22,6 +22,9 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} /** * :: Experimental :: * Trait for calculating information gain. + * This trait is used for + * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] + * (b) calculating impurity values from sufficient statistics. */ @Experimental trait Impurity extends Serializable { @@ -47,3 +50,127 @@ trait Impurity extends Serializable { @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double } + +/** + * Interface for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param statsSize Length of the vector of sufficient statistics for one bin. + */ +private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { + + /** + * Merge the stats from one bin into another. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for (node, feature, bin) which is modified by the merge. + * @param otherOffset Start index of stats for (node, feature, other bin) which is not modified. + */ + def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = { + var i = 0 + while (i < statsSize) { + allStats(offset + i) += allStats(otherOffset + i) + i += 1 + } + } + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: ImpurityCalculator + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double + + /** + * Add the stats from another calculator into this one, modifying and returning this calculator. + */ + def add(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.size == other.stats.size, + s"Two ImpurityCalculator instances cannot be added with different counts sizes." + + s" Sizes are ${stats.size} and ${other.stats.size}.") + var i = 0 + while (i < other.stats.size) { + stats(i) += other.stats(i) + i += 1 + } + this + } + + /** + * Subtract the stats from another calculator from this one, modifying and returning this + * calculator. + */ + def subtract(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.size == other.stats.size, + s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + + s" Sizes are ${stats.size} and ${other.stats.size}.") + var i = 0 + while (i < other.stats.size) { + stats(i) -= other.stats(i) + i += 1 + } + this + } + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double + + /** + * Probability of the label given by [[predict]], or -1 if no probability is available. + */ + def prob(label: Double): Double = -1 + + /** + * Return the index of the largest array element. + * Fails if the array is empty. + */ + protected def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("ImpurityCalculator internal error:" + + " indexOfLargestArrayElement failed") + } + result._1 + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index f7d99a40eb380..e9ccecb1b8067 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -61,3 +61,75 @@ object Variance extends Impurity { def instance = this } + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + */ +private[tree] class VarianceAggregator() + extends ImpurityAggregator(statsSize = 3) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + allStats(offset) += 1 + allStats(offset + 1) += label + allStats(offset + 2) += label * label + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { + new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) + } + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[GiniAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + require(stats.size == 3, + s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + + s" but was given array of length ${stats.size}.") + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: VarianceCalculator = new VarianceCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2)) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats(0).toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + stats(1) / count + } + + override def toString: String = { + s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})" + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index af35d88f713e5..0cad473782af1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ /** - * Used for "binning" the features bins for faster best split calculation. + * Used for "binning" the feature values for faster best split calculation. * * For a continuous feature, the bin is determined by a low and a high split, * where an example with featureValue falls into the bin s.t. @@ -30,13 +30,16 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * bins, splits, and feature values. The bin is determined by category/feature value. * However, the bins are not necessarily ordered by feature value; * they are ordered using impurity. + * * For unordered categorical features, there is a 1-1 correspondence between bins, splits, * where bins and splits correspond to subsets of feature values (in highSplit.categories). + * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all + * partitionings of categories into 2 disjoint, non-empty sets. * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be - * accepted in the bin + * accepted in the bin * @param featureType type of feature -- categorical or continuous * @param category categorical label value accepted in the bin for ordered features */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index fb12298e0f5d3..f3e2619bd8ba0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity - * @param predict predicted value - * @param prob probability of the label (classification only) */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double, - val predict: Double, - val prob: Double = 0.0) extends Serializable { + val rightImpurity: Double) extends Serializable { override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f" - .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" + .format(gain, impurity, leftImpurity, rightImpurity) } } + + +private[tree] object InformationGainStats { + /** + * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to + * denote that current split doesn't satisfies minimum info gain or + * minimum number of instances per node. + */ + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 0eee6262781c1..5b8a4cbed2306 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -24,8 +24,13 @@ import org.apache.spark.mllib.linalg.Vector /** * :: DeveloperApi :: - * Node in a decision tree - * @param id integer node id + * Node in a decision tree. + * + * About node indexing: + * Nodes are indexed from 1. Node 1 is the root; nodes 2, 3 are the left, right children. + * Node index 0 is not used. + * + * @param id integer node id, from 1 * @param predict predicted value at the node * @param isLeaf whether the leaf is a node * @param split split to calculate left and right nodes @@ -51,17 +56,13 @@ class Node ( * @param nodes array of nodes */ def build(nodes: Array[Node]): Unit = { - - logDebug("building node " + id + " at level " + - (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("building node " + id + " at level " + Node.indexToLevel(id)) logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) if (!isLeaf) { - val leftNodeIndex = id * 2 + 1 - val rightNodeIndex = id * 2 + 2 - leftNode = Some(nodes(leftNodeIndex)) - rightNode = Some(nodes(rightNodeIndex)) + leftNode = Some(nodes(Node.leftChildIndex(id))) + rightNode = Some(nodes(Node.rightChildIndex(id))) leftNode.get.build(nodes) rightNode.get.build(nodes) } @@ -96,24 +97,20 @@ class Node ( * Get the number of nodes in tree below this node, including leaf nodes. * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. */ - private[tree] def numDescendants: Int = { - if (isLeaf) { - 0 - } else { - 2 + leftNode.get.numDescendants + rightNode.get.numDescendants - } + private[tree] def numDescendants: Int = if (isLeaf) { + 0 + } else { + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants } /** * Get depth of tree from this node. * E.g.: Depth 0 means this is a leaf node. */ - private[tree] def subtreeDepth: Int = { - if (isLeaf) { - 0 - } else { - 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) - } + private[tree] def subtreeDepth: Int = if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) } /** @@ -148,3 +145,49 @@ class Node ( } } + +private[tree] object Node { + + /** + * Return the index of the left child of this node. + */ + def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 + + /** + * Return the index of the right child of this node. + */ + def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 + + /** + * Get the parent index of the given node, or 0 if it is the root. + */ + def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 + + /** + * Return the level of a tree which the given node is in. + */ + def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { + throw new IllegalArgumentException(s"0 is not a valid node index.") + } else { + java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) + } + + /** + * Returns true if this is a left child. + * Note: Returns false for the root. + */ + def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 + + /** + * Return the maximum number of nodes which can be in the given level of the tree. + * @param level Level of tree (0 = root). + */ + def maxNodesInLevel(level: Int): Int = 1 << level + + /** + * Return the index of the first node in the given level. + * @param level Level of tree (0 = root). + */ + def startIndexInLevel(level: Int): Int = 1 << level + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala new file mode 100644 index 0000000000000..6fac2be2797bc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Predicted value for a node + * @param predict predicted value + * @param prob probability of the label (classification only) + */ +@DeveloperApi +private[tree] class Predict( + val predict: Double, + val prob: Double = 0.0) extends Serializable{ + + override def toString = { + "predict = %f, prob = %f".format(predict, prob) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 50fb48b40de3d..b7a85f58544a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType /** * :: DeveloperApi :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2f36fd907772c..fd8547c1660fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -21,15 +21,15 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -59,12 +59,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } - test("split and bin calculation") { + test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) @@ -72,7 +73,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) } - test("split and bin calculation for categorical variables") { + test("Binary classification with binary (ordered) categorical features:" + + " split and bin calculation") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -83,77 +85,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) assert(splits.length === 2) assert(bins.length === 2) - assert(splits(0).length === 99) - assert(bins(0).length === 100) - - // Check splits. - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(1.0)) - assert(splits(0)(1).categories.contains(0.0)) - - assert(splits(0)(2) === null) - - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(1.0)) - assert(splits(1)(1).categories.contains(0.0)) - - assert(splits(1)(2) === null) - - // Check bins. - - assert(bins(0)(0).category === 1.0) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - - assert(bins(0)(1).category === 0.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(0.0)) - - assert(bins(0)(2) === null) - - assert(bins(1)(0).category === 0.0) - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(1)(1).category === 1.0) - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 2) - assert(bins(1)(1).highSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(1)(2) === null) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } - test("split and bin calculations for categorical variables with no sample for one category") { + test("Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -164,104 +109,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - - // Check splits. - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(1.0)) - assert(splits(0)(1).categories.contains(0.0)) - - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 3) - assert(splits(0)(2).categories.contains(1.0)) - assert(splits(0)(2).categories.contains(0.0)) - assert(splits(0)(2).categories.contains(2.0)) - - assert(splits(0)(3) === null) - - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(1.0)) - assert(splits(1)(1).categories.contains(0.0)) - - assert(splits(1)(2).feature === 1) - assert(splits(1)(2).threshold === Double.MinValue) - assert(splits(1)(2).featureType === Categorical) - assert(splits(1)(2).categories.length === 3) - assert(splits(1)(2).categories.contains(1.0)) - assert(splits(1)(2).categories.contains(0.0)) - assert(splits(1)(2).categories.contains(2.0)) - - assert(splits(1)(3) === null) - - // Check bins. - - assert(bins(0)(0).category === 1.0) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - - assert(bins(0)(1).category === 0.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(0.0)) - - assert(bins(0)(2).category === 2.0) - assert(bins(0)(2).lowSplit.categories.length === 2) - assert(bins(0)(2).lowSplit.categories.contains(1.0)) - assert(bins(0)(2).lowSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.length === 3) - assert(bins(0)(2).highSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.contains(2.0)) - - assert(bins(0)(3) === null) - - assert(bins(1)(0).category === 0.0) - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(1)(1).category === 1.0) - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 2) - assert(bins(1)(1).highSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(1)(2).category === 2.0) - assert(bins(1)(2).lowSplit.categories.length === 2) - assert(bins(1)(2).lowSplit.categories.contains(0.0)) - assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length === 3) - assert(bins(1)(2).highSplit.categories.contains(0.0)) - assert(bins(1)(2).highSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.contains(2.0)) - - assert(bins(1)(3) === null) + assert(splits.length === 2) + assert(bins.length === 2) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } test("extract categories from a number for multiclass classification") { @@ -270,8 +127,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } - test("split and bin calculations for unordered categorical variables with multiclass " + - "classification") { + test("Multiclass classification with unordered categorical features:" + + " split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -282,8 +139,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 3) + assert(bins(0).length === 6) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -321,10 +185,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(1.0)) - assert(splits(0)(3) === null) - assert(splits(1)(3) === null) - - // Check bins. assert(bins(0)(0).category === Double.MinValue) @@ -360,13 +220,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(1)(2).highSplit.categories.contains(1.0)) assert(bins(1)(2).highSplit.categories.contains(0.0)) - assert(bins(0)(3) === null) - assert(bins(1)(3) === null) - } - test("split and bin calculations for ordered categorical variables with multiclass " + - "classification") { + test("Multiclass classification with ordered categorical features: split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() assert(arr.length === 3000) val rdd = sc.parallelize(arr) @@ -377,52 +233,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + // 2^10 - 1 > 100, so categorical features will be ordered + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - - // 2^10 - 1 > 100, so categorical variables will be ordered - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(2.0)) - - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 3) - assert(splits(0)(2).categories.contains(2.0)) - assert(splits(0)(2).categories.contains(1.0)) - - assert(splits(0)(10) === null) - assert(splits(1)(10) === null) - - - // Check bins. - - assert(bins(0)(0).category === 1.0) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - assert(bins(0)(1).category === 2.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(2.0)) - - assert(bins(0)(10) === null) - + assert(splits.length === 2) + assert(bins.length === 2) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } - test("classification stump with all categorical variables") { + test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -433,26 +258,35 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 - assert(split.categories.length === 1) - assert(split.categories.contains(1.0)) + assert(split.categories === List(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 + val predict = bestSplits(0)._3 assert(stats.gain > 0) - assert(stats.predict === 1) - assert(stats.prob === 0.6) + assert(predict.predict === 1) + assert(predict.prob === 0.6) assert(stats.impurity > 0.2) } - test("regression stump with all categorical variables") { + test("Regression stump with 3-ary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -462,10 +296,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 @@ -475,12 +313,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 + val predict = bestSplits(0)._3.predict assert(stats.gain > 0) - assert(stats.predict === 0.6) + assert(predict === 0.6) assert(stats.impurity > 0.2) } - test("regression stump with categorical variables of arity 2") { + test("Regression stump with binary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -490,6 +329,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) validateRegressor(model, arr, 0.0) @@ -497,22 +339,24 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) } - test("stump with fixed label 0 for Gini") { + test("Binary classification stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) + val strategy = new Strategy(Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) assert(bins(0).length === 100) - assert(splits(0).length === 99) - assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -521,82 +365,88 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.rightImpurity === 0) } - test("stump with fixed label 1 for Gini") { + test("Binary classification stump with fixed label 1 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) + val strategy = new Strategy(Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) assert(bins(0).length === 100) - assert(splits(0).length === 99) - assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + assert(bestSplits(0)._3.predict === 1) } - test("stump with fixed label 0 for Entropy") { + test("Binary classification stump with fixed label 0 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val strategy = new Strategy(Classification, Entropy, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) assert(bins(0).length === 100) - assert(splits(0).length === 99) - assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 0) + assert(bestSplits(0)._3.predict === 0) } - test("stump with fixed label 1 for Entropy") { + test("Binary classification stump with fixed label 1 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val strategy = new Strategy(Classification, Entropy, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) assert(bins(0).length === 100) - assert(splits(0).length === 99) - assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + assert(bestSplits(0)._3.predict === 1) } - test("second level node building with/without groups") { + test("Second level node building with vs. without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -607,18 +457,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins.length === 2) assert(bins(0).length === 100) - assert(splits(0).length === 99) - assert(bins(0).length === 100) // Train a 1-node model val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val nodes: Array[Node] = new Array[Node](7) - nodes(0) = modelOneNode.topNode - nodes(0).leftNode = None - nodes(0).rightNode = None + val nodes: Array[Node] = new Array[Node](8) + nodes(1) = modelOneNode.topNode + nodes(1).leftNode = None + nodes(1).rightNode = None - val parentImpurities = Array(0.5, 0.5, 0.5) + val parentImpurities = Array(0, 0.5, 0.5, 0.5) // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) @@ -644,20 +492,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) - assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict) } } - test("stump with categorical variables for multiclass classification") { + test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -668,7 +519,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } - test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) @@ -684,26 +535,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) } - test("stump with 2 continuous variables for binary classification") { + test("Binary classification stump with 2 continuous features") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) assert(model.topNode.split.get.feature === 1) } - test("stump with categorical variables for multiclass classification, with just enough bins") { - val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -711,6 +563,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) @@ -719,7 +573,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -733,11 +587,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain.rightImpurity === 0) } - test("stump with continuous variables for multiclass classification") { + test("Multiclass classification stump with continuous features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3) + numClassesForClassification = 3, maxBins = 100) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) @@ -746,7 +600,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -759,20 +613,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("stump with continuous + categorical variables for multiclass classification") { + test("Multiclass classification stump with continuous + unordered categorical features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) + numClassesForClassification = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -784,17 +639,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.threshold < 2020) } - test("stump with categorical variables for ordered multiclass classification") { + test("Multiclass classification stump with 10-ary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + numClassesForClassification = 3, maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -805,7 +663,104 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 3, maxBins = 10, + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(rdd, strategy) + validateClassifier(model, arr, 0.6) + } + + test("split must satisfy min instances per node requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2) + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + + // test for findBestSplits when no valid split can be found + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, + new Array[Node](0), splits, bins, 10) + + assert(bestSplits.length == 1) + val bestInfoStats = bestSplits(0)._2 + assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) + } + + test("don't choose split that doesn't satisfy min instance per node requirements") { + // if a split doesn't satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + numClassesForClassification = 2, minInstancesPerNode = 2) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, + new Array[Node](0), splits, bins, 10) + + assert(bestSplits.length == 1) + val bestSplit = bestSplits(0)._1 + val bestSplitStats = bestSplits(0)._1 + assert(bestSplit.feature == 1) + assert(bestSplitStats != InformationGainStats.invalidInformationGainStats) + } + + test("split must satisfy min info gain requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, minInfoGain = 1.0) + + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + + // test for findBestSplits when no valid split can be found + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, + new Array[Node](0), splits, bins, 10) + + assert(bestSplits.length == 1) + val bestInfoStats = bestSplits(0)._2 + assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) + } } object DecisionTreeSuite { @@ -899,5 +854,4 @@ object DecisionTreeSuite { arr } - } diff --git a/pom.xml b/pom.xml index 0d44cf4ea5f92..28763476f8313 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -125,6 +125,7 @@ 2.4.1 ${hadoop.version} 0.94.6 + 1.4.0 3.4.5 0.12.0 1.4.3 @@ -220,6 +221,18 @@ false + + + spark-staging-1030 + Spark 1.1.0 Staging (1030) + https://repository.apache.org/content/repositories/orgapachespark-1030/ + + true + + + false + + @@ -260,6 +273,7 @@ com.google.guava guava 14.0.1 + provided org.apache.commons @@ -825,7 +839,6 @@ -unchecked -deprecation -feature - -language:postfixOps -Xms1024m @@ -874,17 +887,18 @@ org.scalatest scalatest-maven-plugin - 1.0-RC2 + 1.0 ${project.build.directory}/surefire-reports . - ${project.build.directory}/SparkTestSuite.txt + SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m true ${session.executionRootDirectory} 1 + false @@ -1017,6 +1031,21 @@ + + + sbt + + + com.google.guava + guava + compile + + + + spark-ganglia-lgpl @@ -1115,18 +1144,49 @@ - mapr + mapr3 false 1.0.3-mapr-3.0.3 - 2.3.0-mapr-4.0.0-beta - 0.94.17-mapr-1403 - 3.4.5-mapr-1401 + 2.3.0-mapr-4.0.0-FCS + 0.94.17-mapr-1405 + 3.4.5-mapr-1406 + + mapr4 + + false + + + 2.3.0-mapr-4.0.0-FCS + 2.3.0-mapr-4.0.0-FCS + 0.94.17-mapr-1405-4.0.0-FCS + 3.4.5-mapr-1406 + + + + org.apache.curator + curator-recipes + 2.4.0 + + + org.apache.zookeeper + zookeeper + + + + + org.apache.zookeeper + zookeeper + 3.4.5-mapr-1406 + + + + hadoop-provided @@ -1179,7 +1239,7 @@ - hive-thriftserver + hive false diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 034ba6a7bf50f..0f5d71afcf616 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -85,7 +85,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.0.0" + val previousSparkVersion = "1.1.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 300589394b96f..46b78bd5c7061 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -33,6 +33,18 @@ import com.typesafe.tools.mima.core._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.2") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + // This is @DeveloperAPI, but Mima still gives false-positives: + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ + Seq( + // This is @Experimental, but Mima still gives false-positives: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync") + ) case v if v.startsWith("1.1") => Seq( MimaBuild.excludeSparkPackage("deploy"), @@ -41,6 +53,9 @@ object MimaExcludes { Seq( // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), + // Should probably mark this as Experimental + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values // for countApproxDistinct* functions, which does not work in Java. We later removed // them, and use the following to tell Mima to not care about them. @@ -58,6 +73,8 @@ object MimaExcludes { "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.DiskStore.getValues"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.storage.MemoryStore.Entry") ) ++ @@ -106,6 +123,8 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Values") ++ MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + // Class was missing "@DeveloperApi" annotation in 1.0. + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ Seq( ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Gini.calculate"), @@ -114,14 +133,14 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) ++ - Seq ( // Package-private classes removed in SPARK-2341 + Seq( // Package-private classes removed in SPARK-2341 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) ++ + ) ++ Seq( // package-private classes removed in MLlib ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 49d52aefca17a..c07ea313f1228 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -61,7 +61,7 @@ object SparkBuild extends PomBuild { def backwardCompatibility = { import scala.collection.mutable var isAlphaYarn = false - var profiles: mutable.Seq[String] = mutable.Seq.empty + var profiles: mutable.Seq[String] = mutable.Seq("sbt") if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") @@ -116,7 +116,7 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - + resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { @@ -184,7 +184,7 @@ object OldDeps { def versionArtifact(id: String): Option[sbt.ModuleID] = { val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.0.0") + Some("org.apache.spark" % fullId % "1.1.0") } def oldDepsSettings() = Defaults.defaultSettings ++ Seq( @@ -290,9 +290,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { @@ -314,7 +314,7 @@ object Unidoc { "-group", "Core Java API", packageList("api.java", "api.java.function"), "-group", "Spark Streaming", packageList( "streaming.api.java", "streaming.flume", "streaming.kafka", - "streaming.mqtt", "streaming.twitter", "streaming.zeromq" + "streaming.mqtt", "streaming.twitter", "streaming.zeromq", "streaming.kinesis" ), "-group", "MLlib", packageList( "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg", @@ -337,7 +337,7 @@ object TestSettings { javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.ports.maxRetries=100", - javaOptions in Test += "-Dspark.ui.port=0", + javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/project/plugins.sbt b/project/plugins.sbt index 2a61f56c2ea60..8096c61414660 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -26,3 +26,7 @@ addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") + +libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" + +libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c58555fc9d2c5..1a2e774738fe7 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -61,13 +61,17 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext -from pyspark.sql import SQLContext from pyspark.rdd import RDD -from pyspark.sql import SchemaRDD -from pyspark.sql import Row from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel +from pyspark.accumulators import Accumulator, AccumulatorParam +from pyspark.broadcast import Broadcast +from pyspark.serializers import MarshalSerializer, PickleSerializer +# for back compatibility +from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row -__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", - "SparkFiles", "StorageLevel", "Row"] +__all__ = [ + "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", + "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", +] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index f133cf6f7befc..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,6 +94,9 @@ from pyspark.serializers import read_int, PickleSerializer +__all__ = ['Accumulator', 'AccumulatorParam'] + + pickleSer = PickleSerializer() # Holds accumulators registered on the current machine, keyed by ID. This is then used to send diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 675a2fcd2ff4e..5c7c9cc161dff 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -31,6 +31,10 @@ from pyspark.serializers import CompressedSerializer, PickleSerializer + +__all__ = ['Broadcast'] + + # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} @@ -59,11 +63,20 @@ def __init__(self, bid, value, java_broadcast=None, """ self.bid = bid if path is None: - self.value = value + self._value = value self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry self.path = path + @property + def value(self): + """ Return the broadcasted value + """ + if not hasattr(self, "_value") and self.path is not None: + ser = CompressedSerializer(PickleSerializer()) + self._value = ser.load_stream(open(self.path)).next() + return self._value + def unpersist(self, blocking=False): self._jbroadcast.unpersist(blocking) os.unlink(self.path) @@ -72,15 +85,6 @@ def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) - def __getattr__(self, item): - if item == 'value' and self.path is not None: - ser = CompressedSerializer(PickleSerializer()) - value = ser.load_stream(open(self.path)).next() - self.value = value - return value - - raise AttributeError(item) - if __name__ == "__main__": import doctest diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 68062483dedaa..80e51d1a583a0 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -657,7 +657,6 @@ def save_partial(self, obj): def save_file(self, obj): """Save a file""" import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute - from ..transport.adapter import SerializingAdapter if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") @@ -691,13 +690,10 @@ def save_file(self, obj): tmpfile.close() if tst != '': raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) - elif fsize > SerializingAdapter.max_transmit_data: - raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % - (name,SerializingAdapter.max_transmit_data)) else: try: tmpfile = file(name) - contents = tmpfile.read(SerializingAdapter.max_transmit_data) + contents = tmpfile.read() tmpfile.close() except IOError: raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index fb716f6753a45..b64875a3f495a 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -54,6 +54,8 @@ (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] """ +__all__ = ['SparkConf'] + class SparkConf(object): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a90870ed3a353..3ab98e262df31 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -37,6 +37,9 @@ from py4j.java_collections import ListConverter +__all__ = ['SparkContext'] + + # These are special default configs for PySpark, they will overwrite # the default ones for Spark if they are not configured by user. DEFAULT_CONFIGS = { @@ -229,6 +232,20 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + def __enter__(self): + """ + Enable 'with SparkContext(...) as sc: app(sc)' syntax. + """ + return self + + def __exit__(self, type, value, trace): + """ + Enable 'with SparkContext(...) as sc: app' syntax. + + Specifically stop the context on exit of the with block. + """ + self.stop() + @classmethod def setSystemProperty(cls, key, value): """ @@ -314,12 +331,16 @@ def pickleFile(self, name, minPartitions=None): return RDD(self._jsc.objectFile(name, minPartitions), self, BatchedSerializer(PickleSerializer())) - def textFile(self, name, minPartitions=None): + def textFile(self, name, minPartitions=None, use_unicode=True): """ Read a text file from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI, and return it as an RDD of Strings. + If use_unicode is False, the strings will be kept as `str` (encoding + as `utf-8`), which is faster and smaller than unicode. (Added in + Spark 1.2) + >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: ... testFile.write("Hello world!") @@ -329,9 +350,9 @@ def textFile(self, name, minPartitions=None): """ minPartitions = minPartitions or min(self.defaultParallelism, 2) return RDD(self._jsc.textFile(name, minPartitions), self, - UTF8Deserializer()) + UTF8Deserializer(use_unicode)) - def wholeTextFiles(self, path, minPartitions=None): + def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): """ Read a directory of text files from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system @@ -339,6 +360,10 @@ def wholeTextFiles(self, path, minPartitions=None): key-value pair, where the key is the path of each file, the value is the content of each file. + If use_unicode is False, the strings will be kept as `str` (encoding + as `utf-8`), which is faster and smaller than unicode. (Added in + Spark 1.2) + For example, if you have the following files:: hdfs://a-hdfs-path/part-00000 @@ -369,7 +394,7 @@ def wholeTextFiles(self, path, minPartitions=None): """ minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, - PairDeserializer(UTF8Deserializer(), UTF8Deserializer())) + PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() @@ -568,8 +593,6 @@ def broadcast(self, value): L{Broadcast} object for reading it in distributed functions. The variable will be sent to each cluster only once. - - :keep: Keep the `value` in driver or not. """ ser = CompressedSerializer(PickleSerializer()) # pass large object by py4j is very slow and need much memory @@ -608,8 +631,8 @@ def addFile(self, path): FTP URI. To access the file in Spark jobs, use - L{SparkFiles.get(path)} to find its - download location. + L{SparkFiles.get(fileName)} with the + filename to find its download location. >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 331de9a9b2212..797573f49dac8 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -18,6 +18,9 @@ import os +__all__ = ['SparkFiles'] + + class SparkFiles(object): """ diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py new file mode 100644 index 0000000000000..bc441f138f7fc --- /dev/null +++ b/python/pyspark/heapq3.py @@ -0,0 +1,890 @@ +# -*- encoding: utf-8 -*- +# back ported from CPython 3 +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +"""Heap queue algorithm (a.k.a. priority queue). + +Heaps are arrays for which a[k] <= a[2*k+1] and a[k] <= a[2*k+2] for +all k, counting elements from 0. For the sake of comparison, +non-existing elements are considered to be infinite. The interesting +property of a heap is that a[0] is always its smallest element. + +Usage: + +heap = [] # creates an empty heap +heappush(heap, item) # pushes a new item on the heap +item = heappop(heap) # pops the smallest item from the heap +item = heap[0] # smallest item on the heap without popping it +heapify(x) # transforms list into a heap, in-place, in linear time +item = heapreplace(heap, item) # pops and returns smallest item, and adds + # new item; the heap size is unchanged + +Our API differs from textbook heap algorithms as follows: + +- We use 0-based indexing. This makes the relationship between the + index for a node and the indexes for its children slightly less + obvious, but is more suitable since Python uses 0-based indexing. + +- Our heappop() method returns the smallest item, not the largest. + +These two make it possible to view the heap as a regular Python list +without surprises: heap[0] is the smallest item, and heap.sort() +maintains the heap invariant! +""" + +# Original code by Kevin O'Connor, augmented by Tim Peters and Raymond Hettinger + +__about__ = """Heap queues + +[explanation by François Pinard] + +Heaps are arrays for which a[k] <= a[2*k+1] and a[k] <= a[2*k+2] for +all k, counting elements from 0. For the sake of comparison, +non-existing elements are considered to be infinite. The interesting +property of a heap is that a[0] is always its smallest element. + +The strange invariant above is meant to be an efficient memory +representation for a tournament. The numbers below are `k', not a[k]: + + 0 + + 1 2 + + 3 4 5 6 + + 7 8 9 10 11 12 13 14 + + 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 + + +In the tree above, each cell `k' is topping `2*k+1' and `2*k+2'. In +an usual binary tournament we see in sports, each cell is the winner +over the two cells it tops, and we can trace the winner down the tree +to see all opponents s/he had. However, in many computer applications +of such tournaments, we do not need to trace the history of a winner. +To be more memory efficient, when a winner is promoted, we try to +replace it by something else at a lower level, and the rule becomes +that a cell and the two cells it tops contain three different items, +but the top cell "wins" over the two topped cells. + +If this heap invariant is protected at all time, index 0 is clearly +the overall winner. The simplest algorithmic way to remove it and +find the "next" winner is to move some loser (let's say cell 30 in the +diagram above) into the 0 position, and then percolate this new 0 down +the tree, exchanging values, until the invariant is re-established. +This is clearly logarithmic on the total number of items in the tree. +By iterating over all items, you get an O(n ln n) sort. + +A nice feature of this sort is that you can efficiently insert new +items while the sort is going on, provided that the inserted items are +not "better" than the last 0'th element you extracted. This is +especially useful in simulation contexts, where the tree holds all +incoming events, and the "win" condition means the smallest scheduled +time. When an event schedule other events for execution, they are +scheduled into the future, so they can easily go into the heap. So, a +heap is a good structure for implementing schedulers (this is what I +used for my MIDI sequencer :-). + +Various structures for implementing schedulers have been extensively +studied, and heaps are good for this, as they are reasonably speedy, +the speed is almost constant, and the worst case is not much different +than the average case. However, there are other representations which +are more efficient overall, yet the worst cases might be terrible. + +Heaps are also very useful in big disk sorts. You most probably all +know that a big sort implies producing "runs" (which are pre-sorted +sequences, which size is usually related to the amount of CPU memory), +followed by a merging passes for these runs, which merging is often +very cleverly organised[1]. It is very important that the initial +sort produces the longest runs possible. Tournaments are a good way +to that. If, using all the memory available to hold a tournament, you +replace and percolate items that happen to fit the current run, you'll +produce runs which are twice the size of the memory for random input, +and much better for input fuzzily ordered. + +Moreover, if you output the 0'th item on disk and get an input which +may not fit in the current tournament (because the value "wins" over +the last output value), it cannot fit in the heap, so the size of the +heap decreases. The freed memory could be cleverly reused immediately +for progressively building a second heap, which grows at exactly the +same rate the first heap is melting. When the first heap completely +vanishes, you switch heaps and start a new run. Clever and quite +effective! + +In a word, heaps are useful memory structures to know. I use them in +a few applications, and I think it is good to keep a `heap' module +around. :-) + +-------------------- +[1] The disk balancing algorithms which are current, nowadays, are +more annoying than clever, and this is a consequence of the seeking +capabilities of the disks. On devices which cannot seek, like big +tape drives, the story was quite different, and one had to be very +clever to ensure (far in advance) that each tape movement will be the +most effective possible (that is, will best participate at +"progressing" the merge). Some tapes were even able to read +backwards, and this was also used to avoid the rewinding time. +Believe me, real good tape sorts were quite spectacular to watch! +From all times, sorting has always been a Great Art! :-) +""" + +__all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge', + 'nlargest', 'nsmallest', 'heappushpop'] + +def heappush(heap, item): + """Push item onto heap, maintaining the heap invariant.""" + heap.append(item) + _siftdown(heap, 0, len(heap)-1) + +def heappop(heap): + """Pop the smallest item off the heap, maintaining the heap invariant.""" + lastelt = heap.pop() # raises appropriate IndexError if heap is empty + if heap: + returnitem = heap[0] + heap[0] = lastelt + _siftup(heap, 0) + return returnitem + return lastelt + +def heapreplace(heap, item): + """Pop and return the current smallest value, and add the new item. + + This is more efficient than heappop() followed by heappush(), and can be + more appropriate when using a fixed-size heap. Note that the value + returned may be larger than item! That constrains reasonable uses of + this routine unless written as part of a conditional replacement: + + if item > heap[0]: + item = heapreplace(heap, item) + """ + returnitem = heap[0] # raises appropriate IndexError if heap is empty + heap[0] = item + _siftup(heap, 0) + return returnitem + +def heappushpop(heap, item): + """Fast version of a heappush followed by a heappop.""" + if heap and heap[0] < item: + item, heap[0] = heap[0], item + _siftup(heap, 0) + return item + +def heapify(x): + """Transform list into a heap, in-place, in O(len(x)) time.""" + n = len(x) + # Transform bottom-up. The largest index there's any point to looking at + # is the largest with a child index in-range, so must have 2*i + 1 < n, + # or i < (n-1)/2. If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so + # j-1 is the largest, which is n//2 - 1. If n is odd = 2*j+1, this is + # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1. + for i in reversed(range(n//2)): + _siftup(x, i) + +def _heappop_max(heap): + """Maxheap version of a heappop.""" + lastelt = heap.pop() # raises appropriate IndexError if heap is empty + if heap: + returnitem = heap[0] + heap[0] = lastelt + _siftup_max(heap, 0) + return returnitem + return lastelt + +def _heapreplace_max(heap, item): + """Maxheap version of a heappop followed by a heappush.""" + returnitem = heap[0] # raises appropriate IndexError if heap is empty + heap[0] = item + _siftup_max(heap, 0) + return returnitem + +def _heapify_max(x): + """Transform list into a maxheap, in-place, in O(len(x)) time.""" + n = len(x) + for i in reversed(range(n//2)): + _siftup_max(x, i) + +# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos +# is the index of a leaf with a possibly out-of-order value. Restore the +# heap invariant. +def _siftdown(heap, startpos, pos): + newitem = heap[pos] + # Follow the path to the root, moving parents down until finding a place + # newitem fits. + while pos > startpos: + parentpos = (pos - 1) >> 1 + parent = heap[parentpos] + if newitem < parent: + heap[pos] = parent + pos = parentpos + continue + break + heap[pos] = newitem + +# The child indices of heap index pos are already heaps, and we want to make +# a heap at index pos too. We do this by bubbling the smaller child of +# pos up (and so on with that child's children, etc) until hitting a leaf, +# then using _siftdown to move the oddball originally at index pos into place. +# +# We *could* break out of the loop as soon as we find a pos where newitem <= +# both its children, but turns out that's not a good idea, and despite that +# many books write the algorithm that way. During a heap pop, the last array +# element is sifted in, and that tends to be large, so that comparing it +# against values starting from the root usually doesn't pay (= usually doesn't +# get us out of the loop early). See Knuth, Volume 3, where this is +# explained and quantified in an exercise. +# +# Cutting the # of comparisons is important, since these routines have no +# way to extract "the priority" from an array element, so that intelligence +# is likely to be hiding in custom comparison methods, or in array elements +# storing (priority, record) tuples. Comparisons are thus potentially +# expensive. +# +# On random arrays of length 1000, making this change cut the number of +# comparisons made by heapify() a little, and those made by exhaustive +# heappop() a lot, in accord with theory. Here are typical results from 3 +# runs (3 just to demonstrate how small the variance is): +# +# Compares needed by heapify Compares needed by 1000 heappops +# -------------------------- -------------------------------- +# 1837 cut to 1663 14996 cut to 8680 +# 1855 cut to 1659 14966 cut to 8678 +# 1847 cut to 1660 15024 cut to 8703 +# +# Building the heap by using heappush() 1000 times instead required +# 2198, 2148, and 2219 compares: heapify() is more efficient, when +# you can use it. +# +# The total compares needed by list.sort() on the same lists were 8627, +# 8627, and 8632 (this should be compared to the sum of heapify() and +# heappop() compares): list.sort() is (unsurprisingly!) more efficient +# for sorting. + +def _siftup(heap, pos): + endpos = len(heap) + startpos = pos + newitem = heap[pos] + # Bubble up the smaller child until hitting a leaf. + childpos = 2*pos + 1 # leftmost child position + while childpos < endpos: + # Set childpos to index of smaller child. + rightpos = childpos + 1 + if rightpos < endpos and not heap[childpos] < heap[rightpos]: + childpos = rightpos + # Move the smaller child up. + heap[pos] = heap[childpos] + pos = childpos + childpos = 2*pos + 1 + # The leaf at pos is empty now. Put newitem there, and bubble it up + # to its final resting place (by sifting its parents down). + heap[pos] = newitem + _siftdown(heap, startpos, pos) + +def _siftdown_max(heap, startpos, pos): + 'Maxheap variant of _siftdown' + newitem = heap[pos] + # Follow the path to the root, moving parents down until finding a place + # newitem fits. + while pos > startpos: + parentpos = (pos - 1) >> 1 + parent = heap[parentpos] + if parent < newitem: + heap[pos] = parent + pos = parentpos + continue + break + heap[pos] = newitem + +def _siftup_max(heap, pos): + 'Maxheap variant of _siftup' + endpos = len(heap) + startpos = pos + newitem = heap[pos] + # Bubble up the larger child until hitting a leaf. + childpos = 2*pos + 1 # leftmost child position + while childpos < endpos: + # Set childpos to index of larger child. + rightpos = childpos + 1 + if rightpos < endpos and not heap[rightpos] < heap[childpos]: + childpos = rightpos + # Move the larger child up. + heap[pos] = heap[childpos] + pos = childpos + childpos = 2*pos + 1 + # The leaf at pos is empty now. Put newitem there, and bubble it up + # to its final resting place (by sifting its parents down). + heap[pos] = newitem + _siftdown_max(heap, startpos, pos) + +def merge(iterables, key=None, reverse=False): + '''Merge multiple sorted inputs into a single sorted output. + + Similar to sorted(itertools.chain(*iterables)) but returns a generator, + does not pull the data into memory all at once, and assumes that each of + the input streams is already sorted (smallest to largest). + + >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25])) + [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25] + + If *key* is not None, applies a key function to each element to determine + its sort order. + + >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len)) + ['dog', 'cat', 'fish', 'horse', 'kangaroo'] + + ''' + + h = [] + h_append = h.append + + if reverse: + _heapify = _heapify_max + _heappop = _heappop_max + _heapreplace = _heapreplace_max + direction = -1 + else: + _heapify = heapify + _heappop = heappop + _heapreplace = heapreplace + direction = 1 + + if key is None: + for order, it in enumerate(map(iter, iterables)): + try: + next = it.next + h_append([next(), order * direction, next]) + except StopIteration: + pass + _heapify(h) + while len(h) > 1: + try: + while True: + value, order, next = s = h[0] + yield value + s[0] = next() # raises StopIteration when exhausted + _heapreplace(h, s) # restore heap condition + except StopIteration: + _heappop(h) # remove empty iterator + if h: + # fast case when only a single iterator remains + value, order, next = h[0] + yield value + for value in next.__self__: + yield value + return + + for order, it in enumerate(map(iter, iterables)): + try: + next = it.next + value = next() + h_append([key(value), order * direction, value, next]) + except StopIteration: + pass + _heapify(h) + while len(h) > 1: + try: + while True: + key_value, order, value, next = s = h[0] + yield value + value = next() + s[0] = key(value) + s[2] = value + _heapreplace(h, s) + except StopIteration: + _heappop(h) + if h: + key_value, order, value, next = h[0] + yield value + for value in next.__self__: + yield value + + +# Algorithm notes for nlargest() and nsmallest() +# ============================================== +# +# Make a single pass over the data while keeping the k most extreme values +# in a heap. Memory consumption is limited to keeping k values in a list. +# +# Measured performance for random inputs: +# +# number of comparisons +# n inputs k-extreme values (average of 5 trials) % more than min() +# ------------- ---------------- --------------------- ----------------- +# 1,000 100 3,317 231.7% +# 10,000 100 14,046 40.5% +# 100,000 100 105,749 5.7% +# 1,000,000 100 1,007,751 0.8% +# 10,000,000 100 10,009,401 0.1% +# +# Theoretical number of comparisons for k smallest of n random inputs: +# +# Step Comparisons Action +# ---- -------------------------- --------------------------- +# 1 1.66 * k heapify the first k-inputs +# 2 n - k compare remaining elements to top of heap +# 3 k * (1 + lg2(k)) * ln(n/k) replace the topmost value on the heap +# 4 k * lg2(k) - (k/2) final sort of the k most extreme values +# +# Combining and simplifying for a rough estimate gives: +# +# comparisons = n + k * (log(k, 2) * log(n/k) + log(k, 2) + log(n/k)) +# +# Computing the number of comparisons for step 3: +# ----------------------------------------------- +# * For the i-th new value from the iterable, the probability of being in the +# k most extreme values is k/i. For example, the probability of the 101st +# value seen being in the 100 most extreme values is 100/101. +# * If the value is a new extreme value, the cost of inserting it into the +# heap is 1 + log(k, 2). +# * The probabilty times the cost gives: +# (k/i) * (1 + log(k, 2)) +# * Summing across the remaining n-k elements gives: +# sum((k/i) * (1 + log(k, 2)) for i in range(k+1, n+1)) +# * This reduces to: +# (H(n) - H(k)) * k * (1 + log(k, 2)) +# * Where H(n) is the n-th harmonic number estimated by: +# gamma = 0.5772156649 +# H(n) = log(n, e) + gamma + 1 / (2 * n) +# http://en.wikipedia.org/wiki/Harmonic_series_(mathematics)#Rate_of_divergence +# * Substituting the H(n) formula: +# comparisons = k * (1 + log(k, 2)) * (log(n/k, e) + (1/n - 1/k) / 2) +# +# Worst-case for step 3: +# ---------------------- +# In the worst case, the input data is reversed sorted so that every new element +# must be inserted in the heap: +# +# comparisons = 1.66 * k + log(k, 2) * (n - k) +# +# Alternative Algorithms +# ---------------------- +# Other algorithms were not used because they: +# 1) Took much more auxiliary memory, +# 2) Made multiple passes over the data. +# 3) Made more comparisons in common cases (small k, large n, semi-random input). +# See the more detailed comparison of approach at: +# http://code.activestate.com/recipes/577573-compare-algorithms-for-heapqsmallest + +def nsmallest(n, iterable, key=None): + """Find the n smallest elements in a dataset. + + Equivalent to: sorted(iterable, key=key)[:n] + """ + + # Short-cut for n==1 is to use min() + if n == 1: + it = iter(iterable) + sentinel = object() + if key is None: + result = min(it, default=sentinel) + else: + result = min(it, default=sentinel, key=key) + return [] if result is sentinel else [result] + + # When n>=size, it's faster to use sorted() + try: + size = len(iterable) + except (TypeError, AttributeError): + pass + else: + if n >= size: + return sorted(iterable, key=key)[:n] + + # When key is none, use simpler decoration + if key is None: + it = iter(iterable) + # put the range(n) first so that zip() doesn't + # consume one too many elements from the iterator + result = [(elem, i) for i, elem in zip(range(n), it)] + if not result: + return result + _heapify_max(result) + top = result[0][0] + order = n + _heapreplace = _heapreplace_max + for elem in it: + if elem < top: + _heapreplace(result, (elem, order)) + top = result[0][0] + order += 1 + result.sort() + return [r[0] for r in result] + + # General case, slowest method + it = iter(iterable) + result = [(key(elem), i, elem) for i, elem in zip(range(n), it)] + if not result: + return result + _heapify_max(result) + top = result[0][0] + order = n + _heapreplace = _heapreplace_max + for elem in it: + k = key(elem) + if k < top: + _heapreplace(result, (k, order, elem)) + top = result[0][0] + order += 1 + result.sort() + return [r[2] for r in result] + +def nlargest(n, iterable, key=None): + """Find the n largest elements in a dataset. + + Equivalent to: sorted(iterable, key=key, reverse=True)[:n] + """ + + # Short-cut for n==1 is to use max() + if n == 1: + it = iter(iterable) + sentinel = object() + if key is None: + result = max(it, default=sentinel) + else: + result = max(it, default=sentinel, key=key) + return [] if result is sentinel else [result] + + # When n>=size, it's faster to use sorted() + try: + size = len(iterable) + except (TypeError, AttributeError): + pass + else: + if n >= size: + return sorted(iterable, key=key, reverse=True)[:n] + + # When key is none, use simpler decoration + if key is None: + it = iter(iterable) + result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)] + if not result: + return result + heapify(result) + top = result[0][0] + order = -n + _heapreplace = heapreplace + for elem in it: + if top < elem: + _heapreplace(result, (elem, order)) + top = result[0][0] + order -= 1 + result.sort(reverse=True) + return [r[0] for r in result] + + # General case, slowest method + it = iter(iterable) + result = [(key(elem), i, elem) for i, elem in zip(range(0, -n, -1), it)] + if not result: + return result + heapify(result) + top = result[0][0] + order = -n + _heapreplace = heapreplace + for elem in it: + k = key(elem) + if top < k: + _heapreplace(result, (k, order, elem)) + top = result[0][0] + order -= 1 + result.sort(reverse=True) + return [r[2] for r in result] + +# If available, use C implementation +try: + from _heapq import * +except ImportError: + pass +try: + from _heapq import _heapreplace_max +except ImportError: + pass +try: + from _heapq import _heapify_max +except ImportError: + pass +try: + from _heapq import _heappop_max +except ImportError: + pass + + +if __name__ == "__main__": + + import doctest + print(doctest.testmod()) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c7f7c1fe591b0..9c70fa5c16d0c 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -15,6 +15,7 @@ # limitations under the License. # +import atexit import os import sys import signal @@ -54,14 +55,37 @@ def preexec_func(): gateway_port = proc.stdout.readline() gateway_port = int(gateway_port) except ValueError: + # Grab the remaining lines of stdout (stdout, _) = proc.communicate() exit_code = proc.poll() error_msg = "Launching GatewayServer failed" - error_msg += " with exit code %d! " % exit_code if exit_code else "! " - error_msg += "(Warning: unexpected output detected.)\n\n" - error_msg += gateway_port + stdout + error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n" + error_msg += "Warning: Expected GatewayServer to output a port, but found " + if gateway_port == "" and stdout == "": + error_msg += "no output.\n" + else: + error_msg += "the following:\n\n" + error_msg += "--------------------------------------------------------------\n" + error_msg += gateway_port + stdout + error_msg += "--------------------------------------------------------------\n" raise Exception(error_msg) + # In Windows, ensure the Java child processes do not linger after Python has exited. + # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when + # the parent process' stdin sends an EOF). In Windows, however, this is not possible + # because java.lang.Process reads directly from the parent process' stdin, contending + # with any opportunity to read an EOF from the parent. Note that this is only best + # effort and will not take effect if the python process is violently terminated. + if on_windows: + # In Windows, the child process here is "spark-submit.cmd", not the JVM itself + # (because the UNIX "exec" command is not available). This means we cannot simply + # call proc.kill(), which kills only the "spark-submit.cmd" process but not the + # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all + # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx) + def killChild(): + Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)]) + atexit.register(killChild) + # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: class EchoOutputThread(Thread): diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index ffdda7ee19302..71ab46b61d7fa 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -30,6 +30,10 @@ from math import exp, log +__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel', + 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + + class LogisticRegressionModel(LinearModel): """A linear binary classification model derived from logistic regression. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a0630d1d5c58b..f3e952a1d842a 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -25,6 +25,8 @@ _get_initial_weights, _serialize_rating, _regression_train_wrapper from pyspark.mllib.linalg import SparseVector +__all__ = ['KMeansModel', 'KMeans'] + class KMeansModel(object): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f485a69db1fa2..e69051c104e37 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -27,6 +27,9 @@ from numpy import array, array_equal, ndarray, float64, int32 +__all__ = ['SparseVector', 'Vectors'] + + class SparseVector(object): """ diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 4dc1a4a912421..d53c95fd59c25 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -25,7 +25,10 @@ from pyspark.serializers import NoOpSerializer -class RandomRDDs: +__all__ = ['RandomRDDs', ] + + +class RandomRDDs(object): """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index e863fc249ec36..2df23394da6f8 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -24,6 +24,8 @@ _serialize_tuple, RatingDeserializer from pyspark.rdd import RDD +__all__ = ['MatrixFactorizationModel', 'ALS'] + class MatrixFactorizationModel(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index d8792cf44872f..f572dcfb840b6 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -17,15 +17,15 @@ from numpy import array, ndarray from pyspark import SparkContext -from pyspark.mllib._common import \ - _dot, _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ - _serialize_double_matrix, _deserialize_double_matrix, \ - _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ +from pyspark.mllib._common import _dot, _regression_train_wrapper, \ _linear_predictor_typecheck, _have_scipy, _scipy_issparse from pyspark.mllib.linalg import SparseVector, Vectors +__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' + 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] + + class LabeledPoint(object): """ diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index feef0d16cd644..8c726f171c978 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -21,8 +21,10 @@ from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ - _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + _serialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +__all__ = ['MultivariateStatisticalSummary', 'Statistics'] class MultivariateStatisticalSummary(object): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index e9d778df5a24b..ccc000ac70ba6 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -26,6 +26,9 @@ from pyspark.serializers import NoOpSerializer +__all__ = ['DecisionTreeModel', 'DecisionTree'] + + class DecisionTreeModel(object): """ @@ -88,6 +91,7 @@ class DecisionTree(object): It will probably be modified for Spark v1.2. Example usage: + >>> from numpy import array >>> import sys >>> from pyspark.mllib.regression import LabeledPoint @@ -134,7 +138,7 @@ class DecisionTree(object): @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=4, maxBins=100): + impurity="gini", maxDepth=5, maxBins=32): """ Train a DecisionTreeModel for classification. @@ -166,7 +170,7 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, @staticmethod def trainRegressor(data, categoricalFeaturesInfo, - impurity="variance", maxDepth=4, maxBins=100): + impurity="variance", maxDepth=5, maxBins=32): """ Train a DecisionTreeModel for regression. diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 4962d05491c03..1c7b8c809ab5b 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -25,7 +25,7 @@ from pyspark.serializers import NoOpSerializer -class MLUtils: +class MLUtils(object): """ Helper methods to load, save and pre-process data used in MLlib. diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3eefc878d274e..5667154cb84a8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -32,7 +32,7 @@ import heapq import bisect from random import Random -from math import sqrt, log +from math import sqrt, log, isinf, isnan from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ @@ -44,10 +44,11 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory + get_used_memory, ExternalSorter from py4j.java_collections import ListConverter, MapConverter + __all__ = ["RDD"] @@ -62,7 +63,7 @@ def portable_hash(x): >>> portable_hash(None) 0 - >>> portable_hash((None, 1)) + >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ if x is None: @@ -72,7 +73,7 @@ def portable_hash(x): for i in x: h ^= portable_hash(i) h *= 1000003 - h &= 0xffffffff + h &= sys.maxint h ^= len(x) if h == -1: h = -2 @@ -131,74 +132,20 @@ def __exit__(self, type, value, tb): self._context._jsc.setCallSite(None) -class MaxHeapQ(object): - - """ - An implementation of MaxHeap. - - >>> import pyspark.rdd - >>> heap = pyspark.rdd.MaxHeapQ(5) - >>> [heap.insert(i) for i in range(10)] - [None, None, None, None, None, None, None, None, None, None] - >>> sorted(heap.getElements()) - [0, 1, 2, 3, 4] - >>> heap = pyspark.rdd.MaxHeapQ(5) - >>> [heap.insert(i) for i in range(9, -1, -1)] - [None, None, None, None, None, None, None, None, None, None] - >>> sorted(heap.getElements()) - [0, 1, 2, 3, 4] - >>> heap = pyspark.rdd.MaxHeapQ(1) - >>> [heap.insert(i) for i in range(9, -1, -1)] - [None, None, None, None, None, None, None, None, None, None] - >>> heap.getElements() - [0] +class BoundedFloat(float): """ + Bounded value is generated by approximate job, with confidence and low + bound and high bound. - def __init__(self, maxsize): - # We start from q[1], so its children are always 2 * k - self.q = [0] - self.maxsize = maxsize - - def _swim(self, k): - while (k > 1) and (self.q[k / 2] < self.q[k]): - self._swap(k, k / 2) - k = k / 2 - - def _swap(self, i, j): - t = self.q[i] - self.q[i] = self.q[j] - self.q[j] = t - - def _sink(self, k): - N = self.size() - while 2 * k <= N: - j = 2 * k - # Here we test if both children are greater than parent - # if not swap with larger one. - if j < N and self.q[j] < self.q[j + 1]: - j = j + 1 - if(self.q[k] > self.q[j]): - break - self._swap(k, j) - k = j - - def size(self): - return len(self.q) - 1 - - def insert(self, value): - if (self.size()) < self.maxsize: - self.q.append(value) - self._swim(self.size()) - else: - self._replaceRoot(value) - - def getElements(self): - return self.q[1:] - - def _replaceRoot(self, value): - if(self.q[1] > value): - self.q[1] = value - self._sink(1) + >>> BoundedFloat(100.0, 0.95, 95.0, 105.0) + 100.0 + """ + def __new__(cls, mean, confidence, low, high): + obj = float.__new__(cls, mean) + obj.confidence = confidence + obj.low = low + obj.high = high + return obj def _parse_memory(s): @@ -232,6 +179,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() + self._partitionFunc = None def _toPickleSerialization(self): if (self._jrdd_deserializer == PickleSerializer() or @@ -264,11 +212,16 @@ def cache(self): self.persist(StorageLevel.MEMORY_ONLY_SER) return self - def persist(self, storageLevel): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): """ Set this RDD's storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + + >>> rdd = sc.parallelize(["b", "a", "c"]) + >>> rdd.persist().is_cached + True """ self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) @@ -309,8 +262,6 @@ def getCheckpointFile(self): checkpointFile = self._jrdd.rdd().getCheckpointFile() if checkpointFile.isDefined(): return checkpointFile.get() - else: - return None def map(self, f, preservesPartitioning=False): """ @@ -350,7 +301,7 @@ def mapPartitions(self, f, preservesPartitioning=False): """ def func(s, iterator): return f(iterator) - return self.mapPartitionsWithIndex(func) + return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitionsWithIndex(self, f, preservesPartitioning=False): """ @@ -400,7 +351,7 @@ def filter(self, f): """ def func(iterator): return ifilter(f, iterator) - return self.mapPartitions(func) + return self.mapPartitions(func, True) def distinct(self): """ @@ -545,7 +496,7 @@ def intersection(self, other): """ return self.map(lambda v: (v, None)) \ .cogroup(other.map(lambda v: (v, None))) \ - .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \ + .filter(lambda (k, vs): all(vs)) \ .keys() def _reserialize(self, serializer=None): @@ -569,6 +520,30 @@ def __add__(self, other): raise TypeError return self.union(other) + def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash, + ascending=True, keyfunc=lambda x: x): + """ + Repartition the RDD according to the given partitioner and, within each resulting partition, + sort records by their keys. + + >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)]) + >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2) + >>> rdd2.glom().collect() + [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]] + """ + if numPartitions is None: + numPartitions = self._defaultReducePartitions() + + spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") + memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + serializer = self._jrdd_deserializer + + def sortPartition(iterator): + sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + + return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) + def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. @@ -589,13 +564,18 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() + spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') + memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + serializer = self._jrdd_deserializer + def sortPartition(iterator): - return iter(sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=not ascending)) + sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) if numPartitions == 1: if self.getNumPartitions() > 1: self = self.coalesce(1) - return self.mapPartitions(sortPartition) + return self.mapPartitions(sortPartition, True) # first compute the boundary of each part via sampling: we want to partition # the key-space into bins such that the bins have roughly the same @@ -700,8 +680,8 @@ def foreach(self, f): def processPartition(iterator): for x in iterator: f(x) - yield None - self.mapPartitions(processPartition).collect() # Force evaluation + return iter([]) + self.mapPartitions(processPartition).count() # Force evaluation def foreachPartition(self, f): """ @@ -710,10 +690,15 @@ def foreachPartition(self, f): >>> def f(iterator): ... for x in iterator: ... print x - ... yield None >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ - self.mapPartitions(f).collect() # Force evaluation + def func(it): + r = f(it) + try: + return iter(r) + except TypeError: + return iter([]) + self.mapPartitions(func).count() # Force evaluation def collect(self): """ @@ -746,18 +731,23 @@ def reduce(self, f): 15 >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) 10 + >>> sc.parallelize([]).reduce(add) + Traceback (most recent call last): + ... + ValueError: Can not reduce() empty RDD """ def func(iterator): - acc = None - for obj in iterator: - if acc is None: - acc = obj - else: - acc = f(obj, acc) - if acc is not None: - yield acc + iterator = iter(iterator) + try: + initial = next(iterator) + except StopIteration: + return + yield reduce(f, iterator, initial) + vals = self.mapPartitions(func).collect() - return reduce(f, vals) + if vals: + return reduce(f, vals) + raise ValueError("Can not reduce() empty RDD") def fold(self, zeroValue, op): """ @@ -810,23 +800,37 @@ def func(iterator): return self.mapPartitions(func).fold(zeroValue, combOp) - def max(self): + def max(self, key=None): """ Find the maximum item in this RDD. - >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max() + @param key: A function used to generate key for comparing + + >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) + >>> rdd.max() 43.0 + >>> rdd.max(key=str) + 5.0 """ - return self.reduce(max) + if key is None: + return self.reduce(max) + return self.reduce(lambda a, b: max(a, b, key=key)) - def min(self): + def min(self, key=None): """ Find the minimum item in this RDD. - >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min() - 1.0 + @param key: A function used to generate key for comparing + + >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0]) + >>> rdd.min() + 2.0 + >>> rdd.min(key=str) + 10.0 """ - return self.reduce(min) + if key is None: + return self.reduce(min) + return self.reduce(lambda a, b: min(a, b, key=key)) def sum(self): """ @@ -856,6 +860,133 @@ def redFunc(left_counter, right_counter): return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc) + def histogram(self, buckets): + """ + Compute a histogram using the provided buckets. The buckets + are all open to the right except for the last which is closed. + e.g. [1,10,20,50] means the buckets are [1,10) [10,20) [20,50], + which means 1<=x<10, 10<=x<20, 20<=x<=50. And on the input of 1 + and 50 we would have a histogram of 1,0,1. + + If your histogram is evenly spaced (e.g. [0, 10, 20, 30]), + this can be switched from an O(log n) inseration to O(1) per + element(where n = # buckets). + + Buckets must be sorted and not contain any duplicates, must be + at least two elements. + + If `buckets` is a number, it will generates buckets which are + evenly spaced between the minimum and maximum of the RDD. For + example, if the min value is 0 and the max is 100, given buckets + as 2, the resulting buckets will be [0,50) [50,100]. buckets must + be at least 1 If the RDD contains infinity, NaN throws an exception + If the elements in RDD do not vary (max == min) always returns + a single bucket. + + It will return an tuple of buckets and histogram. + + >>> rdd = sc.parallelize(range(51)) + >>> rdd.histogram(2) + ([0, 25, 50], [25, 26]) + >>> rdd.histogram([0, 5, 25, 50]) + ([0, 5, 25, 50], [5, 20, 26]) + >>> rdd.histogram([0, 15, 30, 45, 60]) # evenly spaced buckets + ([0, 15, 30, 45, 60], [15, 15, 15, 6]) + >>> rdd = sc.parallelize(["ab", "ac", "b", "bd", "ef"]) + >>> rdd.histogram(("a", "b", "c")) + (('a', 'b', 'c'), [2, 2]) + """ + + if isinstance(buckets, (int, long)): + if buckets < 1: + raise ValueError("number of buckets must be >= 1") + + # filter out non-comparable elements + def comparable(x): + if x is None: + return False + if type(x) is float and isnan(x): + return False + return True + + filtered = self.filter(comparable) + + # faster than stats() + def minmax(a, b): + return min(a[0], b[0]), max(a[1], b[1]) + try: + minv, maxv = filtered.map(lambda x: (x, x)).reduce(minmax) + except TypeError as e: + if " empty " in str(e): + raise ValueError("can not generate buckets from empty RDD") + raise + + if minv == maxv or buckets == 1: + return [minv, maxv], [filtered.count()] + + try: + inc = (maxv - minv) / buckets + except TypeError: + raise TypeError("Can not generate buckets with non-number in RDD") + + if isinf(inc): + raise ValueError("Can not generate buckets with infinite value") + + # keep them as integer if possible + if inc * buckets != maxv - minv: + inc = (maxv - minv) * 1.0 / buckets + + buckets = [i * inc + minv for i in range(buckets)] + buckets.append(maxv) # fix accumulated error + even = True + + elif isinstance(buckets, (list, tuple)): + if len(buckets) < 2: + raise ValueError("buckets should have more than one value") + + if any(i is None or isinstance(i, float) and isnan(i) for i in buckets): + raise ValueError("can not have None or NaN in buckets") + + if sorted(buckets) != list(buckets): + raise ValueError("buckets should be sorted") + + if len(set(buckets)) != len(buckets): + raise ValueError("buckets should not contain duplicated values") + + minv = buckets[0] + maxv = buckets[-1] + even = False + inc = None + try: + steps = [buckets[i + 1] - buckets[i] for i in range(len(buckets) - 1)] + except TypeError: + pass # objects in buckets do not support '-' + else: + if max(steps) - min(steps) < 1e-10: # handle precision errors + even = True + inc = (maxv - minv) / (len(buckets) - 1) + + else: + raise TypeError("buckets should be a list or tuple or number(int or long)") + + def histogram(iterator): + counters = [0] * len(buckets) + for i in iterator: + if i is None or (type(i) is float and isnan(i)) or i > maxv or i < minv: + continue + t = (int((i - minv) / inc) if even + else bisect.bisect_right(buckets, i) - 1) + counters[t] += 1 + # add last two together + last = counters.pop() + counters[-1] += last + return [counters] + + def mergeCounters(a, b): + return [i + j for i, j in zip(a, b)] + + return buckets, self.mapPartitions(histogram).reduce(mergeCounters) + def mean(self): """ Compute the mean of this RDD's elements. @@ -919,12 +1050,12 @@ def countPartition(iterator): yield counts def mergeMaps(m1, m2): - for (k, v) in m2.iteritems(): + for k, v in m2.iteritems(): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) - def top(self, num): + def top(self, num, key=None): """ Get the top N elements from a RDD. @@ -933,20 +1064,16 @@ def top(self, num): [12] >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2) [6, 5] + >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str) + [4, 3, 2] """ def topIterator(iterator): - q = [] - for k in iterator: - if len(q) < num: - heapq.heappush(q, k) - else: - heapq.heappushpop(q, k) - yield q + yield heapq.nlargest(num, iterator, key=key) def merge(a, b): - return next(topIterator(a + b)) + return heapq.nlargest(num, a + b, key=key) - return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True) + return self.mapPartitions(topIterator).reduce(merge) def takeOrdered(self, num, key=None): """ @@ -959,24 +1086,10 @@ def takeOrdered(self, num, key=None): [10, 9, 7, 6, 5, 4] """ - def topNKeyedElems(iterator, key_=None): - q = MaxHeapQ(num) - for k in iterator: - if key_ is not None: - k = (key_(k), k) - q.insert(k) - yield q.getElements() - - def unKey(x, key_=None): - if key_ is not None: - x = [i[1] for i in x] - return x - def merge(a, b): - return next(topNKeyedElems(a + b)) - result = self.mapPartitions( - lambda i: topNKeyedElems(i, key)).reduce(merge) - return sorted(unKey(result, key), key=key) + return heapq.nsmallest(num, a + b, key) + + return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge) def take(self, num): """ @@ -1005,24 +1118,24 @@ def take(self, num): # we actually cap it at totalParts in runJob. numPartsToTry = 1 if partsScanned > 0: - # If we didn't find any rows after the first iteration, just - # try all partitions next. Otherwise, interpolate the number - # of partitions we need to try, but overestimate it by 50%. + # If we didn't find any rows after the previous iteration, + # quadruple and retry. Otherwise, interpolate the number of + # partitions we need to try, but overestimate it by 50%. if len(items) == 0: - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 else: numPartsToTry = int(1.5 * num * partsScanned / len(items)) left = num - len(items) def takeUpToNumLeft(iterator): + iterator = iter(iterator) taken = 0 while taken < left: yield next(iterator) taken += 1 - p = range( - partsScanned, min(partsScanned + numPartsToTry, totalParts)) + p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) res = self.context.runJob(self, takeUpToNumLeft, p, True) items += res @@ -1036,8 +1149,15 @@ def first(self): >>> sc.parallelize([2, 3, 4]).first() 2 + >>> sc.parallelize([]).first() + Traceback (most recent call last): + ... + ValueError: RDD is empty """ - return self.take(1)[0] + rs = self.take(1) + if rs: + return rs[0] + raise ValueError("RDD is empty") def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1262,13 +1382,13 @@ def reduceByKeyLocally(self, func): """ def reducePartition(iterator): m = {} - for (k, v) in iterator: - m[k] = v if k not in m else func(m[k], v) + for k, v in iterator: + m[k] = func(m[k], v) if k in m else v yield m def mergeMaps(m1, m2): - for (k, v) in m2.iteritems(): - m1[k] = v if k not in m1 else func(m1[k], v) + for k, v in m2.iteritems(): + m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1365,7 +1485,7 @@ def add_shuffle_key(split, iterator): buckets = defaultdict(list) c, batch = 0, min(10 * numPartitions, 1000) - for (k, v) in iterator: + for k, v in iterator: buckets[partitionFunc(k) % numPartitions].append((k, v)) c += 1 @@ -1388,7 +1508,7 @@ def add_shuffle_key(split, iterator): batch = max(batch / 1.5, 1) c = 0 - for (split, items) in buckets.iteritems(): + for split, items in buckets.iteritems(): yield pack_long(split) yield outputSerializer.dumps(items) @@ -1458,7 +1578,7 @@ def _mergeCombiners(iterator): merger.mergeCombiners(iterator) return merger.iteritems() - return shuffled.mapPartitions(_mergeCombiners) + return shuffled.mapPartitions(_mergeCombiners, True) def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ @@ -1522,7 +1642,6 @@ def mergeCombiners(a, b): return self.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions).mapValues(lambda x: ResultIterable(x)) - # TODO: add tests def flatMapValues(self, f): """ Pass each value in the key-value pair RDD through a flatMap function @@ -1612,9 +1731,8 @@ def subtractByKey(self, other, numPartitions=None): [('b', 4), ('b', 5)] """ def filter_func((key, vals)): - return len(vals[0]) > 0 and len(vals[1]) == 0 - map_func = lambda (key, vals): [(key, val) for val in vals[0]] - return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func) + return vals[0] and not vals[1] + return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0]) def subtract(self, other, numPartitions=None): """ @@ -1627,7 +1745,7 @@ def subtract(self, other, numPartitions=None): """ # note: here 'True' is just a placeholder rdd = other.map(lambda x: (x, True)) - return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) + return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys() def keyBy(self, f): """ @@ -1715,14 +1833,60 @@ def batch_as(rdd, batchSize): other._jrdd_deserializer) return RDD(pairRDD, self.ctx, deserializer) + def zipWithIndex(self): + """ + Zips this RDD with its element indices. + + The ordering is first based on the partition index and then the + ordering of items within each partition. So the first item in + the first partition gets index 0, and the last item in the last + partition receives the largest index. + + This method needs to trigger a spark job when this RDD contains + more than one partitions. + + >>> sc.parallelize(["a", "b", "c", "d"], 3).zipWithIndex().collect() + [('a', 0), ('b', 1), ('c', 2), ('d', 3)] + """ + starts = [0] + if self.getNumPartitions() > 1: + nums = self.mapPartitions(lambda it: [sum(1 for i in it)]).collect() + for i in range(len(nums) - 1): + starts.append(starts[-1] + nums[i]) + + def func(k, it): + for i, v in enumerate(it, starts[k]): + yield v, i + + return self.mapPartitionsWithIndex(func) + + def zipWithUniqueId(self): + """ + Zips this RDD with generated unique Long ids. + + Items in the kth partition will get ids k, n+k, 2*n+k, ..., where + n is the number of partitions. So there may exist gaps, but this + method won't trigger a spark job, which is different from + L{zipWithIndex} + + >>> sc.parallelize(["a", "b", "c", "d", "e"], 3).zipWithUniqueId().collect() + [('a', 0), ('b', 1), ('c', 4), ('d', 2), ('e', 5)] + """ + n = self.getNumPartitions() + + def func(k, it): + for i, v in enumerate(it): + yield v, i * n + k + + return self.mapPartitionsWithIndex(func) + def name(self): """ Return the name of this RDD. """ name_ = self._jrdd.name() - if not name_: - return None - return name_.encode('utf-8') + if name_: + return name_.encode('utf-8') def setName(self, name): """ @@ -1740,9 +1904,8 @@ def toDebugString(self): A description of this RDD and its recursive dependencies for debugging. """ debug_string = self._jrdd.toDebugString() - if not debug_string: - return None - return debug_string.encode('utf-8') + if debug_string: + return debug_string.encode('utf-8') def getStorageLevel(self): """ @@ -1777,10 +1940,122 @@ def _defaultReducePartitions(self): else: return self.getNumPartitions() - # TODO: `lookup` is disabled because we can't make direct comparisons based - # on the key; we need to compare the hash of the key to the hash of the - # keys in the pairs. This could be an expensive operation, since those - # hashes aren't retained. + def lookup(self, key): + """ + Return the list of values in the RDD for key `key`. This operation + is done efficiently if the RDD has a known partitioner by only + searching the partition that the key maps to. + + >>> l = range(1000) + >>> rdd = sc.parallelize(zip(l, l), 10) + >>> rdd.lookup(42) # slow + [42] + >>> sorted = rdd.sortByKey() + >>> sorted.lookup(42) # fast + [42] + >>> sorted.lookup(1024) + [] + """ + values = self.filter(lambda (k, v): k == key).values() + + if self._partitionFunc is not None: + return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False) + + return values.collect() + + def _is_pickled(self): + """ Return this RDD is serialized by Pickle or not. """ + der = self._jrdd_deserializer + if isinstance(der, PickleSerializer): + return True + if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer): + return True + return False + + def _to_java_object_rdd(self): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + if not self._is_pickled(): + self = self._reserialize(BatchedSerializer(PickleSerializer(), 1024)) + batched = isinstance(self._jrdd_deserializer, BatchedSerializer) + return self.ctx._jvm.PythonRDD.pythonToJava(self._jrdd, batched) + + def countApprox(self, timeout, confidence=0.95): + """ + :: Experimental :: + Approximate version of count() that returns a potentially incomplete + result within a timeout, even if not all tasks have finished. + + >>> rdd = sc.parallelize(range(1000), 10) + >>> rdd.countApprox(1000, 1.0) + 1000 + """ + drdd = self.mapPartitions(lambda it: [float(sum(1 for i in it))]) + return int(drdd.sumApprox(timeout, confidence)) + + def sumApprox(self, timeout, confidence=0.95): + """ + :: Experimental :: + Approximate operation to return the sum within a timeout + or meet the confidence. + + >>> rdd = sc.parallelize(range(1000), 10) + >>> r = sum(xrange(1000)) + >>> (rdd.sumApprox(1000) - r) / r < 0.05 + True + """ + jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd() + jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd()) + r = jdrdd.sumApprox(timeout, confidence).getFinalValue() + return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high()) + + def meanApprox(self, timeout, confidence=0.95): + """ + :: Experimental :: + Approximate operation to return the mean within a timeout + or meet the confidence. + + >>> rdd = sc.parallelize(range(1000), 10) + >>> r = sum(xrange(1000)) / 1000.0 + >>> (rdd.meanApprox(1000) - r) / r < 0.05 + True + """ + jrdd = self.map(float)._to_java_object_rdd() + jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd()) + r = jdrdd.meanApprox(timeout, confidence).getFinalValue() + return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high()) + + def countApproxDistinct(self, relativeSD=0.05): + """ + :: Experimental :: + Return approximate number of distinct elements in the RDD. + + The algorithm used is based on streamlib's implementation of + "HyperLogLog in Practice: Algorithmic Engineering of a State + of The Art Cardinality Estimation Algorithm", available + here. + + @param relativeSD Relative accuracy. Smaller values create + counters that require more space. + It must be greater than 0.000017. + + >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct() + >>> 950 < n < 1050 + True + >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct() + >>> 18 < n < 22 + True + """ + if relativeSD < 0.000017: + raise ValueError("relativeSD should be greater than 0.000017") + if relativeSD > 0.37: + raise ValueError("relativeSD should be smaller than 0.37") + # the hash space in Java is 2^32 + hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) + return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) class PipelinedRDD(RDD): @@ -1824,8 +2099,10 @@ def pipeline_func(split, iterator): self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False + self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None @property def _jrdd(self): @@ -1853,6 +2130,11 @@ def _jrdd(self): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val + def id(self): + if self._id is None: + self._id = self._jrdd.id() + return self._id + def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fc49aa42dbaf9..7b2710b913128 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -409,7 +409,7 @@ def loads(self, obj): class CompressedSerializer(FramedSerializer): """ - compress the serialized data + Compress the serialized data """ def __init__(self, serializer): @@ -429,18 +429,22 @@ class UTF8Deserializer(Serializer): Deserializes streams written by String.getBytes. """ + def __init__(self, use_unicode=False): + self.use_unicode = use_unicode + def loads(self, stream): length = read_int(stream) - return stream.read(length).decode('utf8') + s = stream.read(length) + return s.decode("utf-8") if self.use_unicode else s def load_stream(self, stream): - while True: - try: + try: + while True: yield self.loads(stream) - except struct.error: - return - except EOFError: - return + except struct.error: + return + except EOFError: + return def read_long(stream): diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index e1e7cd954189f..89cf76920e353 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -28,6 +28,7 @@ sys.exit(1) +import atexit import os import platform import pyspark @@ -42,14 +43,15 @@ SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) sc = SparkContext(appName="PySparkShell", pyFiles=add_files) +atexit.register(lambda: sc.stop()) print("""Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ - /__ / .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT + /__ / .__/\_,_/_/ /_/\_\ version %s /_/ -""") +""" % sc.version) print("Using Python version %s (%s, %s)" % ( platform.python_version(), platform.python_build()[0], diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 1ebe7df418327..49829f5280a5f 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -21,7 +21,10 @@ import shutil import warnings import gc +import itertools +import random +import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer try: @@ -54,6 +57,17 @@ def get_used_memory(): return 0 +def _get_local_dirs(sub): + """ Get all the directories """ + path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") + dirs = path.split(",") + if len(dirs) > 1: + # different order in different processes and instances + rnd = random.Random(os.getpid() + id(dirs)) + random.shuffle(dirs, rnd.random) + return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs] + + class Aggregator(object): """ @@ -196,7 +210,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, # default serializer is only used for tests self.serializer = serializer or \ BatchedSerializer(PickleSerializer(), 1024) - self.localdirs = localdirs or self._get_dirs() + self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions # check the memory after # of items merged @@ -212,13 +226,6 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, # randomize the hash of key, id(o) is the address of o (aligned by 8) self._seed = id(self) + 7 - def _get_dirs(self): - """ Get all the directories """ - path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") - dirs = path.split(",") - return [os.path.join(d, "python", str(os.getpid()), str(id(self))) - for d in dirs] - def _get_spill_dir(self, n): """ Choose one directory for spill by number n """ return os.path.join(self.localdirs[n % len(self.localdirs)], str(n)) @@ -434,6 +441,74 @@ def _recursive_merged_items(self, start): os.remove(os.path.join(path, str(i))) +class ExternalSorter(object): + """ + ExtenalSorter will divide the elements into chunks, sort them in + memory and dump them into disks, finally merge them back. + + The spilling will only happen when the used memory goes above + the limit. + + >>> sorter = ExternalSorter(1) # 1M + >>> import random + >>> l = range(1024) + >>> random.shuffle(l) + >>> sorted(l) == list(sorter.sorted(l)) + True + >>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True)) + True + """ + def __init__(self, memory_limit, serializer=None): + self.memory_limit = memory_limit + self.local_dirs = _get_local_dirs("sort") + self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) + self._spilled_bytes = 0 + + def _get_path(self, n): + """ Choose one directory for spill by number n """ + d = self.local_dirs[n % len(self.local_dirs)] + if not os.path.exists(d): + os.makedirs(d) + return os.path.join(d, str(n)) + + def sorted(self, iterator, key=None, reverse=False): + """ + Sort the elements in iterator, do external sort when the memory + goes above the limit. + """ + batch = 10 + chunks, current_chunk = [], [] + iterator = iter(iterator) + while True: + # pick elements in batch + chunk = list(itertools.islice(iterator, batch)) + current_chunk.extend(chunk) + if len(chunk) < batch: + break + + if get_used_memory() > self.memory_limit: + # sort them inplace will save memory + current_chunk.sort(key=key, reverse=reverse) + path = self._get_path(len(chunks)) + with open(path, 'w') as f: + self.serializer.dump_stream(current_chunk, f) + self._spilled_bytes += os.path.getsize(path) + chunks.append(self.serializer.load_stream(open(path))) + current_chunk = [] + + elif not chunks: + batch = min(batch * 2, 10000) + + current_chunk.sort(key=key, reverse=reverse) + if not chunks: + return current_chunk + + if current_chunk: + chunks.append(iter(current_chunk)) + + return heapq.merge(chunks, key=key, reverse=reverse) + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d4ca0cc8f336e..53eea6d6cf3ba 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -29,6 +29,7 @@ from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from pyspark.storagelevel import StorageLevel from itertools import chain, ifilter, imap @@ -40,8 +41,7 @@ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", - "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "SchemaRDD", "Row"] class DataType(object): @@ -186,15 +186,15 @@ class ArrayType(DataType): """ - def __init__(self, elementType, containsNull=False): + def __init__(self, elementType, containsNull=True): """Creates an ArrayType :param elementType: the data type of elements. :param containsNull: indicates whether the list contains None values. - >>> ArrayType(StringType) == ArrayType(StringType, False) + >>> ArrayType(StringType) == ArrayType(StringType, True) True - >>> ArrayType(StringType, True) == ArrayType(StringType) + >>> ArrayType(StringType, False) == ArrayType(StringType) False """ self.elementType = elementType @@ -899,9 +899,9 @@ def __reduce__(self): return Row -class SQLContext: +class SQLContext(object): - """Main entry point for SparkSQL functionality. + """Main entry point for Spark SQL functionality. A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as tables, execute SQL over tables, cache tables, and read parquet files. @@ -943,18 +943,16 @@ def __init__(self, sparkContext, sqlContext=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray - - if sqlContext: - self._scala_SQLContext = sqlContext + self._scala_SQLContext = sqlContext @property def _ssql_ctx(self): - """Accessor for the JVM SparkSQL context. + """Accessor for the JVM Spark SQL context. Subclasses can override this property to provide their own JVM Contexts. """ - if not hasattr(self, '_scala_SQLContext'): + if self._scala_SQLContext is None: self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext @@ -971,23 +969,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] - >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - [Row(c0=5)] """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) + pickled_command = CloudPickleSerializer().dumps(command) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self._sc._pickled_broadcast_vars], + self._sc._gateway._gateway_client) + self._sc._pickled_broadcast_vars.clear() env = MapConverter().convert(self._sc.environment, self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, self._sc._gateway._gateway_client) self._ssql_ctx.registerPython(name, - bytearray(CloudPickleSerializer().dumps(command)), + bytearray(pickled_command), env, includes, self._sc.pythonExec, + broadcast_vars, self._sc._javaAccumulator, str(returnType)) @@ -1037,7 +1038,7 @@ def inferSchema(self, rdd): "can not infer schema") if type(first) is dict: warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.Row instead") + "please use pyspark.sql.Row instead") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) @@ -1487,12 +1488,27 @@ def __repr__(self): return "" % ", ".join(self) +def inherit_doc(cls): + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls + + +@inherit_doc class SchemaRDD(RDD): """An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by SparkSQL. + utilize the relational query api exposed by Spark SQL. For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on directly, as it's underlying @@ -1509,7 +1525,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.sql_ctx = sql_ctx self._sc = sql_ctx._sc self._jschema_rdd = jschema_rdd - + self._id = None self.is_cached = False self.is_checkpointed = False self.ctx = self.sql_ctx._sc @@ -1527,9 +1543,10 @@ def _jrdd(self): self._lazy_jrdd = self._jschema_rdd.javaToPython() return self._lazy_jrdd - @property - def _id(self): - return self._jrdd.id() + def id(self): + if self._id is None: + self._id = self._jrdd.id() + return self._id def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. @@ -1563,6 +1580,7 @@ def registerTempTable(self, name): self._jschema_rdd.registerTempTable(name) def registerAsTable(self, name): + """DEPRECATED: use registerTempTable() instead""" warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) self.registerTempTable(name) @@ -1649,7 +1667,7 @@ def cache(self): self._jschema_rdd.cache() return self - def persist(self, storageLevel): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) self._jschema_rdd.persist(javaStorageLevel) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 2aa0fb9d2c1ed..676aa0f7144aa 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -18,7 +18,7 @@ __all__ = ["StorageLevel"] -class StorageLevel: +class StorageLevel(object): """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 51bfbb47e53c2..bb84ebe72cb24 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -30,6 +30,7 @@ import tempfile import time import zipfile +import random if sys.version_info[:2] <= (2, 6): import unittest2 as unittest @@ -37,10 +38,12 @@ import unittest +from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger +from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.sql import SQLContext, IntegerType _have_scipy = False _have_numpy = False @@ -117,6 +120,44 @@ def test_huge_dataset(self): m._cleanup() +class TestSorter(unittest.TestCase): + def test_in_memory_sort(self): + l = range(1024) + random.shuffle(l) + sorter = ExternalSorter(1024) + self.assertEquals(sorted(l), list(sorter.sorted(l))) + self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + + def test_external_sort(self): + l = range(1024) + random.shuffle(l) + sorter = ExternalSorter(1) + self.assertEquals(sorted(l), list(sorter.sorted(l))) + self.assertGreater(sorter._spilled_bytes, 0) + last = sorter._spilled_bytes + self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertGreater(sorter._spilled_bytes, last) + last = sorter._spilled_bytes + self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertGreater(sorter._spilled_bytes, last) + last = sorter._spilled_bytes + self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertGreater(sorter._spilled_bytes, last) + + def test_external_sort_in_rdd(self): + conf = SparkConf().set("spark.python.worker.memory", "1m") + sc = SparkContext(conf=conf) + l = range(10240) + random.shuffle(l) + rdd = sc.parallelize(l, 10) + self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + class SerializationTestCase(unittest.TestCase): def test_namedtuple(self): @@ -128,6 +169,17 @@ def test_namedtuple(self): self.assertEquals(p1, p2) +# Regression test for SPARK-3415 +class CloudPickleTest(unittest.TestCase): + def test_pickling_file_handles(self): + from pyspark.cloudpickle import dumps + from StringIO import StringIO + from pickle import load + out1 = sys.stderr + out2 = load(StringIO(dumps(out1))) + self.assertEquals(out1, out2) + + class PySparkTestCase(unittest.TestCase): def setUp(self): @@ -240,6 +292,15 @@ def func(): class TestRDDFunctions(PySparkTestCase): + def test_id(self): + rdd = self.sc.parallelize(range(10)) + id = rdd.id() + self.assertEqual(id, rdd.id()) + rdd2 = rdd.map(str).filter(bool) + id2 = rdd2.id() + self.assertEqual(id + 1, id2) + self.assertEqual(id2, rdd2.id()) + def test_failed_sparkcontext_creation(self): # Regression test for SPARK-1550 self.sc.stop() @@ -364,6 +425,155 @@ def test_zip_with_different_number_of_items(self): self.assertEquals(a.count(), b.count()) self.assertRaises(Exception, lambda: a.zip(b).count()) + def test_count_approx_distinct(self): + rdd = self.sc.parallelize(range(1000)) + self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050) + + rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) + self.assertTrue(18 < rdd.countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) + + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5)) + + def test_histogram(self): + # empty + rdd = self.sc.parallelize([]) + self.assertEquals([0], rdd.histogram([0, 10])[1]) + self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertRaises(ValueError, lambda: rdd.histogram(1)) + + # out of range + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEquals([0], rdd.histogram([0, 10])[1]) + self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1]) + + # in range with one bucket + rdd = self.sc.parallelize(range(1, 5)) + self.assertEquals([4], rdd.histogram([0, 10])[1]) + self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1]) + + # in range with one bucket exact match + self.assertEquals([4], rdd.histogram([1, 4])[1]) + + # out of range with two buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1]) + + # out of range with two uneven buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + + # in range with two buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two bucket and None + rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) + self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two uneven buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEquals([3, 2], rdd.histogram([0, 5, 11])[1]) + + # mixed range with two uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) + self.assertEquals([4, 3], rdd.histogram([0, 5, 11])[1]) + + # mixed range with four uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) + self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # mixed range with uneven buckets and NaN + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, + 199.0, 200.0, 200.1, None, float('nan')]) + self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # out of range with infinite buckets + rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) + self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + + # invalid buckets + self.assertRaises(ValueError, lambda: rdd.histogram([])) + self.assertRaises(ValueError, lambda: rdd.histogram([1])) + self.assertRaises(ValueError, lambda: rdd.histogram(0)) + self.assertRaises(TypeError, lambda: rdd.histogram({})) + + # without buckets + rdd = self.sc.parallelize(range(1, 5)) + self.assertEquals(([1, 4], [4]), rdd.histogram(1)) + + # without buckets single element + rdd = self.sc.parallelize([1]) + self.assertEquals(([1, 1], [1]), rdd.histogram(1)) + + # without bucket no range + rdd = self.sc.parallelize([1] * 4) + self.assertEquals(([1, 1], [4]), rdd.histogram(1)) + + # without buckets basic two + rdd = self.sc.parallelize(range(1, 5)) + self.assertEquals(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + + # without buckets with more requested than elements + rdd = self.sc.parallelize([1, 2]) + buckets = [1 + 0.2 * i for i in range(6)] + hist = [1, 0, 0, 0, 1] + self.assertEquals((buckets, hist), rdd.histogram(5)) + + # invalid RDDs + rdd = self.sc.parallelize([1, float('inf')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + rdd = self.sc.parallelize([float('nan')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + + # string + rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) + self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1)) + self.assertRaises(TypeError, lambda: rdd.histogram(2)) + + # mixed RDD + rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2) + self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1]) + self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1]) + self.assertEquals(([1, "b"], [5]), rdd.histogram(1)) + self.assertRaises(TypeError, lambda: rdd.histogram(2)) + + def test_repartitionAndSortWithinPartitions(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) + partitions = repartitioned.glom().collect() + self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)]) + + +class TestSQL(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.sqlCtx = SQLContext(self.sc) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + class TestIO(PySparkTestCase): @@ -1044,6 +1254,35 @@ def test_single_script_on_cluster(self): self.assertIn("[2, 4, 6]", out) +class ContextStopTests(unittest.TestCase): + + def test_stop(self): + sc = SparkContext() + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_exception(self): + try: + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + raise Exception() + except: + pass + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_stop(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/run-tests b/python/run-tests index 7b1ee3e1cddba..d98840de59d2c 100755 --- a/python/run-tests +++ b/python/run-tests @@ -19,7 +19,7 @@ # Figure out where the Spark framework is installed -FWDIR="$(cd `dirname $0`; cd ../; pwd)" +FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" # CD into the python directory to find things on the right path cd "$FWDIR/python" @@ -28,12 +28,14 @@ FAILED=0 rm -f unit-tests.log -# Remove the metastore and warehouse directory created by the HiveContext tests in SparkSQL +# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL rm -rf metastore warehouse function run_test() { echo "Running test: $1" - SPARK_TESTING=1 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log + + SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + FAILED=$((PIPESTATUS[0]||$FAILED)) # Fail and exit on the first test failure. @@ -48,6 +50,8 @@ function run_test() { echo "Running PySpark tests. Output is in python/unit-tests.log." +export PYSPARK_PYTHON="python" + # Try to test with Python 2.6, since that's the minimum version that we support: if [ $(which python2.6) ]; then export PYSPARK_PYTHON="python2.6" diff --git a/repl/pom.xml b/repl/pom.xml index 68f4504450778..fcc5f90d870e8 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 687e85ca94d3c..5ee325008a5cd 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -21,10 +21,10 @@ import java.io.{ByteArrayOutputStream, InputStream} import java.net.{URI, URL, URLEncoder} import java.util.concurrent.{Executors, ExecutorService} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils import org.apache.spark.util.ParentClassLoader @@ -36,7 +36,7 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ * used to load classes defined by the interpreter when the REPL is used. * Allows the user to specify if user class path should be first */ -class ExecutorClassLoader(classUri: String, parent: ClassLoader, +class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader { val uri = new URI(classUri) val directory = uri.getPath @@ -48,7 +48,7 @@ class ExecutorClassLoader(classUri: String, parent: ClassLoader, if (uri.getScheme() == "http") { null } else { - FileSystem.get(uri, new Configuration()) + FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) } } diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 65788f4646d91..e56b74edba88c 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -8,28 +8,32 @@ package org.apache.spark.repl +import java.net.URL + +import scala.reflect.io.AbstractFile import scala.tools.nsc._ +import scala.tools.nsc.backend.JavaPlatform import scala.tools.nsc.interpreter._ -import scala.tools.nsc.interpreter.{ Results => IR } -import Predef.{ println => _, _ } -import java.io.{ BufferedReader, FileReader } +import scala.tools.nsc.interpreter.{Results => IR} +import Predef.{println => _, _} +import java.io.{BufferedReader, FileReader} +import java.net.URI import java.util.concurrent.locks.ReentrantLock import scala.sys.process.Process import scala.tools.nsc.interpreter.session._ -import scala.util.Properties.{ jdkHome, javaVersion } -import scala.tools.util.{ Javap } +import scala.util.Properties.{jdkHome, javaVersion} +import scala.tools.util.{Javap} import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import scala.concurrent.ops -import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import scala.tools.nsc.util._ import scala.tools.nsc.interpreter._ -import scala.tools.nsc.io.{ File, Directory } +import scala.tools.nsc.io.{File, Directory} import scala.reflect.NameTransformer._ -import scala.tools.nsc.util.ScalaClassLoader import scala.tools.nsc.util.ScalaClassLoader._ import scala.tools.util._ -import scala.language.{implicitConversions, existentials} +import scala.language.{implicitConversions, existentials, postfixOps} import scala.reflect.{ClassTag, classTag} import scala.tools.reflect.StdRuntimeTags._ @@ -186,8 +190,16 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, require(settings != null) if (addedClasspath != "") settings.classpath.append(addedClasspath) + val addedJars = + if (Utils.isWindows) { + // Strip any URI scheme prefix so we can add the correct path to the classpath + // e.g. file:/C:/my/path.jar -> C:/my/path.jar + SparkILoop.getAddedJars.map { jar => new URI(jar).getPath.stripPrefix("/") } + } else { + SparkILoop.getAddedJars + } // work around for Scala bug - val totalClassPath = SparkILoop.getAddedJars.foldLeft( + val totalClassPath = addedJars.foldLeft( settings.classpath.value)((l, r) => ClassPath.join(l, r)) this.settings.classpath.value = totalClassPath @@ -711,22 +723,24 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, added = true addedClasspath = ClassPath.join(addedClasspath, f.path) totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + intp.addUrlsToClassPath(f.toURI.toURL) + sparkContext.addJar(f.toURI.toURL.getPath) } } - if (added) replay() } def addClasspath(arg: String): Unit = { val f = File(arg).normalize if (f.exists) { addedClasspath = ClassPath.join(addedClasspath, f.path) - val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) - echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) - replay() + intp.addUrlsToClassPath(f.toURI.toURL) + sparkContext.addJar(f.toURI.toURL.getPath) + echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, intp.global.classPath.asClasspathString)) } else echo("The path '" + f + "' doesn't seem to exist.") } + def powerCmd(): Result = { if (isReplPower) "Already in power mode." else enablePowerMode(false) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 910b31d209e13..7667a9c11979e 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -14,6 +14,8 @@ import scala.reflect.internal.util.Position import scala.util.control.Exception.ignoring import scala.tools.nsc.util.stackTraceString +import org.apache.spark.SPARK_VERSION + /** * Machinery for the asynchronous initialization of the repl. */ @@ -26,9 +28,9 @@ trait SparkILoopInit { ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT + /___/ .__/\_,_/_/ /_/\_\ version %s /_/ -""") +""".format(SPARK_VERSION)) import Properties._ val welcomeMsg = "Using Scala %s (%s, Java %s)".format( versionString, javaVmName, javaVersion) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 84b57cd2dc1af..6ddb6accd696b 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -7,11 +7,14 @@ package org.apache.spark.repl +import java.io.File + import scala.tools.nsc._ +import scala.tools.nsc.backend.JavaPlatform import scala.tools.nsc.interpreter._ import Predef.{ println => _, _ } -import util.stringFromWriter +import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString} import scala.reflect.internal.util._ import java.net.URL import scala.sys.BooleanProp @@ -21,7 +24,6 @@ import reporters._ import symtab.Flags import scala.reflect.internal.Names import scala.tools.util.PathResolver -import scala.tools.nsc.util.ScalaClassLoader import ScalaClassLoader.URLClassLoader import scala.tools.nsc.util.Exceptional.unwrap import scala.collection.{ mutable, immutable } @@ -34,7 +36,6 @@ import scala.reflect.runtime.{ universe => ru } import scala.reflect.{ ClassTag, classTag } import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable -import util.stackTraceString import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} import org.apache.spark.util.Utils @@ -130,6 +131,9 @@ import org.apache.spark.util.Utils private var _classLoader: AbstractFileClassLoader = null // active classloader private val _compiler: Global = newCompiler(settings, reporter) // our private compiler + private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) } + private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL + private val nextReqId = { var counter = 0 () => { counter += 1 ; counter } @@ -308,6 +312,57 @@ import org.apache.spark.util.Utils } } + /** + * Adds any specified jars to the compile and runtime classpaths. + * + * @note Currently only supports jars, not directories + * @param urls The list of items to add to the compile and runtime classpaths + */ + def addUrlsToClassPath(urls: URL*): Unit = { + new Run // Needed to force initialization of "something" to correctly load Scala classes from jars + urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution + updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling + } + + protected def updateCompilerClassPath(urls: URL*): Unit = { + require(!global.forMSIL) // Only support JavaPlatform + + val platform = global.platform.asInstanceOf[JavaPlatform] + + val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*) + + // NOTE: Must use reflection until this is exposed/fixed upstream in Scala + val fieldSetter = platform.getClass.getMethods + .find(_.getName.endsWith("currentClassPath_$eq")).get + fieldSetter.invoke(platform, Some(newClassPath)) + + // Reload all jars specified into our compiler + global.invalidateClassPathEntries(urls.map(_.getPath): _*) + } + + protected def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { + // Collect our new jars/directories and add them to the existing set of classpaths + val allClassPaths = ( + platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ + urls.map(url => { + platform.classPath.context.newClassPath( + if (url.getProtocol == "file") { + val f = new File(url.getPath) + if (f.isDirectory) + io.AbstractFile.getDirectory(f) + else + io.AbstractFile.getFile(f) + } else { + io.AbstractFile.getURL(url) + } + ) + }) + ).distinct + + // Combine all of our classpaths (old and new) into one merged classpath + new MergedClassPath(allClassPaths, platform.classPath.context) + } + /** Parent classloader. Overridable. */ protected def parentClassLoader: ClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) @@ -356,7 +411,9 @@ import org.apache.spark.util.Utils private def makeClassLoader(): AbstractFileClassLoader = new TranslatingClassLoader(parentClassLoader match { case null => ScalaClassLoader fromURLs compilerClasspath - case p => new URLClassLoader(compilerClasspath, p) + case p => + _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl + _runtimeClassLoader }) def getInterpreterClassLoader() = classLoader diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c0af7ceb6d3ef..3e2ee7541f40d 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import com.google.common.io.Files -import org.apache.spark.TestUtils +import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.util.Utils class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { @@ -57,7 +57,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") @@ -65,7 +65,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(url1, parentLoader, false) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, false) val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -73,7 +73,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -81,7 +81,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() } diff --git a/sbin/slaves.sh b/sbin/slaves.sh index f89547fef9e46..1d4dc5edf9858 100755 --- a/sbin/slaves.sh +++ b/sbin/slaves.sh @@ -36,29 +36,29 @@ if [ $# -le 0 ]; then exit 1 fi -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" # If the slaves file is specified in the command line, # then it takes precedence over the definition in # spark-env.sh. Save it here. -HOSTLIST=$SPARK_SLAVES +HOSTLIST="$SPARK_SLAVES" # Check if --config is passed as an argument. It is an optional parameter. # Exit if the argument is not a directory. if [ "$1" == "--config" ] then shift - conf_dir=$1 + conf_dir="$1" if [ ! -d "$conf_dir" ] then echo "ERROR : $conf_dir is not a directory" echo $usage exit 1 else - export SPARK_CONF_DIR=$conf_dir + export SPARK_CONF_DIR="$conf_dir" fi shift fi @@ -79,7 +79,7 @@ if [ "$SPARK_SSH_OPTS" = "" ]; then fi for slave in `cat "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do - ssh $SPARK_SSH_OPTS $slave $"${@// /\\ }" \ + ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ 2>&1 | sed "s/^/$slave: /" & if [ "$SPARK_SLAVE_SLEEP" != "" ]; then sleep $SPARK_SLAVE_SLEEP diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 5c87da5815b64..2718d6cba1c9a 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -21,19 +21,19 @@ # resolve links - $0 may be a softlink this="${BASH_SOURCE-$0}" -common_bin=$(cd -P -- "$(dirname -- "$this")" && pwd -P) +common_bin="$(cd -P -- "$(dirname -- "$this")" && pwd -P)" script="$(basename -- "$this")" this="$common_bin/$script" # convert relative path to absolute path -config_bin=`dirname "$this"` -script=`basename "$this"` -config_bin=`cd "$config_bin"; pwd` +config_bin="`dirname "$this"`" +script="`basename "$this"`" +config_bin="`cd "$config_bin"; pwd`" this="$config_bin/$script" -export SPARK_PREFIX=`dirname "$this"`/.. -export SPARK_HOME=${SPARK_PREFIX} +export SPARK_PREFIX="`dirname "$this"`"/.. +export SPARK_HOME="${SPARK_PREFIX}" export SPARK_CONF_DIR="$SPARK_HOME/conf" # Add the PySpark classes to the PYTHONPATH: -export PYTHONPATH=$SPARK_HOME/python:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH +export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH" +export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 323f675b17848..bd476b400e1c3 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -37,8 +37,8 @@ if [ $# -le 1 ]; then exit 1 fi -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" @@ -50,14 +50,14 @@ sbin=`cd "$sbin"; pwd` if [ "$1" == "--config" ] then shift - conf_dir=$1 + conf_dir="$1" if [ ! -d "$conf_dir" ] then echo "ERROR : $conf_dir is not a directory" echo $usage exit 1 else - export SPARK_CONF_DIR=$conf_dir + export SPARK_CONF_DIR="$conf_dir" fi shift fi @@ -100,12 +100,12 @@ if [ "$SPARK_LOG_DIR" = "" ]; then export SPARK_LOG_DIR="$SPARK_HOME/logs" fi mkdir -p "$SPARK_LOG_DIR" -touch $SPARK_LOG_DIR/.spark_test > /dev/null 2>&1 +touch "$SPARK_LOG_DIR"/.spark_test > /dev/null 2>&1 TEST_LOG_DIR=$? if [ "${TEST_LOG_DIR}" = "0" ]; then - rm -f $SPARK_LOG_DIR/.spark_test + rm -f "$SPARK_LOG_DIR"/.spark_test else - chown $SPARK_IDENT_STRING $SPARK_LOG_DIR + chown "$SPARK_IDENT_STRING" "$SPARK_LOG_DIR" fi if [ "$SPARK_PID_DIR" = "" ]; then @@ -113,10 +113,8 @@ if [ "$SPARK_PID_DIR" = "" ]; then fi # some variables -export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.log -export SPARK_ROOT_LOGGER="INFO,DRFA" -log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out -pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid +log="$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out" +pid="$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid" # Set default scheduling priority if [ "$SPARK_NICENESS" = "" ]; then @@ -138,7 +136,7 @@ case $startStop in fi if [ "$SPARK_MASTER" != "" ]; then - echo rsync from $SPARK_MASTER + echo rsync from "$SPARK_MASTER" rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' $SPARK_MASTER/ "$SPARK_HOME" fi diff --git a/sbin/spark-executor b/sbin/spark-executor index 3621321a9bc8d..674ce906d9421 100755 --- a/sbin/spark-executor +++ b/sbin/spark-executor @@ -17,10 +17,10 @@ # limitations under the License. # -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -export PYTHONPATH=$FWDIR/python:$PYTHONPATH -export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH +export PYTHONPATH="$FWDIR/python:$PYTHONPATH" +export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" echo "Running spark-executor with framework dir = $FWDIR" -exec $FWDIR/bin/spark-class org.apache.spark.executor.MesosExecutorBackend +exec "$FWDIR"/bin/spark-class org.apache.spark.executor.MesosExecutorBackend diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 5c89ab4d86b3a..1baf57cea09ee 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -21,8 +21,8 @@ # Starts the master on this node. # Starts a worker on each node specified in conf/slaves -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" TACHYON_STR="" diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh index e30493da32a7a..7172ad15d88fc 100755 --- a/sbin/start-history-server.sh +++ b/sbin/start-history-server.sh @@ -24,8 +24,11 @@ # Use the SPARK_HISTORY_OPTS environment variable to set history server configuration. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" if [ $# != 0 ]; then echo "Using command line arguments for setting the log directory is deprecated. Please " diff --git a/sbin/start-master.sh b/sbin/start-master.sh index c5c02491f78e1..17fff58f4f768 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -19,8 +19,8 @@ # Starts the master on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" START_TACHYON=false diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index b563400dc24f3..2fc35309f4ca5 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -20,7 +20,7 @@ # Usage: start-slave.sh # where is like "spark://localhost:7077" -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 4912d0c0c7dfd..ba1a84abc1fef 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -17,8 +17,8 @@ # limitations under the License. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" START_TACHYON=false @@ -46,11 +46,11 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_IP" = "" ]; then - SPARK_MASTER_IP=`hostname` + SPARK_MASTER_IP="`hostname`" fi if [ "$START_TACHYON" == "true" ]; then - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" # set -t so we can call sudo SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/../tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 @@ -58,12 +58,12 @@ fi # Launch the slaves if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT + exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" else if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then SPARK_WORKER_WEBUI_PORT=8081 fi for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) + "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) done fi diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 2c4452473ccbc..4ce40fe750384 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -24,9 +24,10 @@ set -o posix # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" +CLASS_NOT_FOUND_EXIT_STATUS=1 function usage { echo "Usage: ./sbin/start-thriftserver [options] [thrift server options]" @@ -37,42 +38,28 @@ function usage { pattern+="\|=======" pattern+="\|--help" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 echo echo "Thrift server options:" - $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 } -function ensure_arg_number { - arg_number=$1 - at_least=$2 - - if [[ $arg_number -lt $at_least ]]; then - usage - exit 1 - fi -} - -if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then usage exit 0 fi -THRIFT_SERVER_ARGS=() -SUBMISSION_ARGS=() +source "$FWDIR"/bin/utils.sh +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" -while (($#)); do - case $1 in - --hiveconf) - ensure_arg_number $# 2 - THRIFT_SERVER_ARGS+=("$1"); shift - THRIFT_SERVER_ARGS+=("$1"); shift - ;; +"$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" +exit_status=$? - *) - SUBMISSION_ARGS+=("$1"); shift - ;; - esac -done +if [[ exit_status -eq CLASS_NOT_FOUND_EXIT_STATUS ]]; then + echo + echo "Failed to load Hive Thrift server main class $CLASS." + echo "You need to build Spark with -Phive." +fi -exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${THRIFT_SERVER_ARGS[@]}" +exit $exit_status diff --git a/sbin/stop-all.sh b/sbin/stop-all.sh index 60b358d374565..298c6a9859795 100755 --- a/sbin/stop-all.sh +++ b/sbin/stop-all.sh @@ -21,8 +21,8 @@ # Run this on the master nde -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" # Load the Spark configuration . "$sbin/spark-config.sh" diff --git a/sbin/stop-history-server.sh b/sbin/stop-history-server.sh index c0034ad641cbe..6e6056359510f 100755 --- a/sbin/stop-history-server.sh +++ b/sbin/stop-history-server.sh @@ -19,7 +19,7 @@ # Stops the history server on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.history.HistoryServer 1 diff --git a/sbt/sbt b/sbt/sbt index 1b1aa1483a829..c172fa74bc771 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -3,32 +3,32 @@ # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so # that we can run Hive to generate the golden answer. This is not required for normal development # or testing. -for i in $HIVE_HOME/lib/* -do HADOOP_CLASSPATH=$HADOOP_CLASSPATH:$i +for i in "$HIVE_HOME"/lib/* +do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" done export HADOOP_CLASSPATH realpath () { ( - TARGET_FILE=$1 + TARGET_FILE="$1" - cd $(dirname $TARGET_FILE) - TARGET_FILE=$(basename $TARGET_FILE) + cd "$(dirname "$TARGET_FILE")" + TARGET_FILE="$(basename "$TARGET_FILE")" COUNT=0 while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] do - TARGET_FILE=$(readlink $TARGET_FILE) - cd $(dirname $TARGET_FILE) - TARGET_FILE=$(basename $TARGET_FILE) + TARGET_FILE="$(readlink "$TARGET_FILE")" + cd $(dirname "$TARGET_FILE") + TARGET_FILE="$(basename $TARGET_FILE)" COUNT=$(($COUNT + 1)) done - echo $(pwd -P)/$TARGET_FILE + echo "$(pwd -P)/"$TARGET_FILE"" ) } -. $(dirname $(realpath $0))/sbt-launch-lib.bash +. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" diff --git a/sbt/sbt-launch-lib.bash b/sbt/sbt-launch-lib.bash index c91fecf024ad4..7f05d2ef491a3 100755 --- a/sbt/sbt-launch-lib.bash +++ b/sbt/sbt-launch-lib.bash @@ -7,7 +7,7 @@ # TODO - Should we merge the main SBT script with this library? if test -z "$HOME"; then - declare -r script_dir="$(dirname $script_path)" + declare -r script_dir="$(dirname "$script_path")" else declare -r script_dir="$HOME/.sbt" fi @@ -46,20 +46,20 @@ acquire_sbt_jar () { if [[ ! -f "$sbt_jar" ]]; then # Download sbt launch jar if it hasn't been downloaded yet - if [ ! -f ${JAR} ]; then + if [ ! -f "${JAR}" ]; then # Download printf "Attempting to fetch sbt\n" - JAR_DL=${JAR}.part + JAR_DL="${JAR}.part" if hash curl 2>/dev/null; then - (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" elif hash wget 2>/dev/null; then - (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 fi fi - if [ ! -f ${JAR} ]; then + if [ ! -f "${JAR}" ]; then # We failed to download printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n" exit -1 diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 830711a46a35b..0d756f873e486 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0d26b52a84695..88a8fa7c28e0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ @@ -32,6 +31,15 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) + /** Converts Scala objects to catalyst rows / types */ + def convertToCatalyst(a: Any): Any = a match { + case o: Option[_] => o.orNull + case s: Seq[_] => s.map(convertToCatalyst) + case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } + case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case other => other + } + /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => @@ -62,11 +70,14 @@ object ScalaReflection { sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - Schema(ArrayType(schemaFor(elementType).dataType), nullable = true) + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< typeOf[Map[_,_]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala old mode 100644 new mode 100755 index 2c73a80f64ebf..ca69531c69a77 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -73,6 +73,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val ASC = Keyword("ASC") protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AVG = Keyword("AVG") + protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CACHE = Keyword("CACHE") protected val CAST = Keyword("CAST") @@ -81,6 +82,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val DISTINCT = Keyword("DISTINCT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") + protected val LAST = Keyword("LAST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") @@ -114,6 +116,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val STRING = Keyword("STRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") + protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") @@ -122,6 +125,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val EXCEPT = Keyword("EXCEPT") protected val SUBSTR = Keyword("SUBSTR") protected val SUBSTRING = Keyword("SUBSTRING") + protected val SQRT = Keyword("SQRT") + protected val ABS = Keyword("ABS") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -270,6 +275,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | + termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { + case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + } | termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | @@ -309,6 +317,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } | FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | + LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } | AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } | MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } | @@ -323,6 +332,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) } | + SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | + ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) } @@ -346,18 +357,27 @@ class SqlParser extends StandardTokenParsers with PackratParsers { expression ~ "[" ~ expression <~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | + (expression <~ ".") ~ ident ^^ { + case base ~ fieldName => GetField(base, fieldName) + } | TRUE ^^^ Literal(true, BooleanType) | FALSE ^^^ Literal(false, BooleanType) | cast | "(" ~> expression <~ ")" | function | "-" ~> literal ^^ UnaryMinus | + dotExpressionHeader | ident ^^ UnresolvedAttribute | "*" ^^^ Star(None) | literal + protected lazy val dotExpressionHeader: Parser[Expression] = + (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { + case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", "")) + } + protected lazy val dataType: Parser[DataType] = - STRING ^^^ StringType + STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType } class SqlLexical(val keywords: Seq[String]) extends StdLexical { @@ -369,7 +389,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]" + ",", ";", "%", "{", "}", ":", "[", "]", "." ) override lazy val token: Parser[Token] = ( @@ -390,7 +410,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { | failure("illegal character") ) - override def identChar = letter | elem('_') | elem('.') + override def identChar = letter | elem('_') override def whitespace: Parser[Any] = rep( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c18d7858f0a43..574d96d92942b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -40,7 +40,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool // TODO: pass this in as a parameter. val fixedPoint = FixedPoint(100) - val batches: Seq[Batch] = Seq( + /** + * Override to provide additional rules for the "Resolution" batch. + */ + val extendedRules: Seq[Rule[LogicalPlan]] = Nil + + lazy val batches: Seq[Batch] = Seq( Batch("MultiInstanceRelations", Once, NewRelationInstances), Batch("CaseInsensitiveAttributeReferences", Once, @@ -54,8 +59,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool StarExpansion :: ResolveFunctions :: GlobalAggregates :: - UnresolvedHavingClauseAttributes :: - typeCoercionRules :_*), + UnresolvedHavingClauseAttributes :: + typeCoercionRules ++ + extendedRules : _*), Batch("Check Analysis", Once, CheckResolution), Batch("AnalysisOperators", fixedPoint, @@ -63,7 +69,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ) /** - * Makes sure all attributes have been resolved. + * Makes sure all attributes and logical plans have been resolved. */ object CheckResolution extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { @@ -71,6 +77,13 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case p if p.expressions.exists(!_.resolved) => throw new TreeNodeException(p, s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}") + case p if !p.resolved && p.childrenResolved => + throw new TreeNodeException(p, "Unresolved plan found") + } match { + // As a backstop, use the root node to check that the entire plan tree is resolved. + case p if !p.resolved => + throw new TreeNodeException(p, "Unresolved plan in tree") + case p => p } } } @@ -132,7 +145,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) val resolved = unresolved.flatMap(child.resolveChildren) - val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet + val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) val missingInProject = requiredAttributes -- p.output if (missingInProject.nonEmpty) { @@ -152,8 +165,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ) logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve).toSet - val missingInAggs = resolved -- a.outputSet + val resolved = unresolved.flatMap(groupingRelation.resolve) + val missingInAggs = resolved.filterNot(a.outputSet.contains) logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") if (missingInAggs.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 15eb5982a4a91..79e5283e86a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -26,10 +26,22 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: val numericPrecedence = - Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) - // Boolean is only wider than Void - val booleanPrecedence = Seq(NullType, BooleanType) - val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) + val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil + + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + val valueTypes = Seq(t1, t2).filter(t => t != NullType) + if (valueTypes.distinct.size > 1) { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found return the widest common type, otherwise None + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } else { + Some(if (valueTypes.size == 0) NullType else valueTypes.head) + } + } } /** @@ -53,17 +65,6 @@ trait HiveTypeCoercion { Division :: Nil - trait TypeWidening { - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) - } - } - /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -144,7 +145,8 @@ trait HiveTypeCoercion { * - LongType to FloatType * - LongType to DoubleType */ - object WidenTypes extends Rule[LogicalPlan] with TypeWidening { + object WidenTypes extends Rule[LogicalPlan] { + import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transform { case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -218,15 +220,27 @@ trait HiveTypeCoercion { case a: BinaryArithmetic if a.right.dataType == StringType => a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + case p: BinaryPredicate if p.left.dataType == StringType + && p.right.dataType == TimestampType => + p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) + case p: BinaryPredicate if p.left.dataType == TimestampType + && p.right.dataType == StringType => + p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) + case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) + case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => + i.makeCopy(Array(a,b.map(Cast(_,TimestampType)))) + case Sum(e) if e.dataType == StringType => Sum(Cast(e, DoubleType)) case Average(e) if e.dataType == StringType => Average(Cast(e, DoubleType)) + case Sqrt(e) if e.dataType == StringType => + Sqrt(Cast(e, DoubleType)) } } @@ -272,6 +286,10 @@ trait HiveTypeCoercion { // If the data type is not boolean and is being cast boolean, turn it into a comparison // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) + // Stringify boolean if casting to StringType. + // TODO Ensure true/false string letter casing is consistent with Hive in all cases. + case Cast(e, StringType) if e.dataType == BooleanType => + If(e, Literal("true"), Literal("false")) // Turn true into 1, and false into 0 if casting boolean into other types. case Cast(e, dataType) if e.dataType == BooleanType => Cast(If(e, Literal(1), Literal(0)), dataType) @@ -340,7 +358,9 @@ trait HiveTypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening { + object CaseWhenCoercion extends Rule[LogicalPlan] { + import HiveTypeCoercion._ + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => val valueTypes = branches.sliding(2, 2).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a0e25775da6dd..a2c61c65487cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -66,7 +66,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override def dataType = throw new UnresolvedException(this, "dataType") override def foldable = throw new UnresolvedException(this, "foldable") override def nullable = throw new UnresolvedException(this, "nullable") - override def references = children.flatMap(_.references).toSet override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala old mode 100644 new mode 100755 index f44521d6381c9..deb622c39faf5 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -132,6 +132,7 @@ package object dsl { def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) def avg(e: Expression) = Average(e) def first(e: Expression) = First(e) + def last(e: Expression) = Last(e) def min(e: Expression) = Min(e) def max(e: Expression) = Max(e) def upper(e: Expression) = Upper(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala new file mode 100644 index 0000000000000..8364379644c90 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +/** + * Builds a map that is keyed by an Attribute's expression id. Using the expression id allows values + * to be looked up even when the attributes used differ cosmetically (i.e., the capitalization + * of the name, or the expected nullability). + */ +object AttributeMap { + def apply[A](kvs: Seq[(Attribute, A)]) = + new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap) +} + +class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) + extends Map[Attribute, A] with Serializable { + + override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) + + override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = + (baseMap.map(_._2) + kv).toMap + + override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator + + override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala new file mode 100644 index 0000000000000..c3a08bbdb6bc7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +protected class AttributeEquals(val a: Attribute) { + override def hashCode() = a.exprId.hashCode() + override def equals(other: Any) = other match { + case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId + case otherAttribute => false + } +} + +object AttributeSet { + /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */ + def apply(baseSet: Seq[Attribute]) = { + new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet) + } +} + +/** + * A Set designed to hold [[AttributeReference]] objects, that performs equality checking using + * expression id instead of standard java equality. Using expression id means that these + * sets will correctly test for membership, even when the AttributeReferences in question differ + * cosmetically (e.g., the names have different capitalizations). + * + * Note that we do not override equality for Attribute references as it is really weird when + * `AttributeReference("a"...) == AttrributeReference("b", ...)`. This tactic leads to broken tests, + * and also makes doing transformations hard (we always try keep older trees instead of new ones + * when the transformation was a no-op). + */ +class AttributeSet private (val baseSet: Set[AttributeEquals]) + extends Traversable[Attribute] with Serializable { + + /** Returns true if the members of this AttributeSet and other are the same. */ + override def equals(other: Any) = other match { + case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + case _ => false + } + + /** Returns true if this set contains an Attribute with the same expression id as `elem` */ + def contains(elem: NamedExpression): Boolean = + baseSet.contains(new AttributeEquals(elem.toAttribute)) + + /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */ + def +(elem: Attribute): AttributeSet = // scalastyle:ignore + new AttributeSet(baseSet + new AttributeEquals(elem)) + + /** Returns a new [[AttributeSet]] that does not contain `elem`. */ + def -(elem: Attribute): AttributeSet = + new AttributeSet(baseSet - new AttributeEquals(elem)) + + /** Returns an iterator containing all of the attributes in the set. */ + def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator + + /** + * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in + * `other`. + */ + def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + + /** + * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found + * in `other`. + */ + def --(other: Traversable[NamedExpression]) = + new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + + /** + * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found + * in `other`. + */ + def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + + /** + * Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] where `f` evaluates to + * true. + */ + override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + + /** + * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in + * `this` and `other`. + */ + def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + + override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) + + // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all + // sorts of things in its closure. + override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 0913f15888780..fa80b07f8e6be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,20 +32,26 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def references = Set.empty - override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } object BindReferences extends Logging { - def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { + + def bindReference[A <: Expression]( + expression: A, + input: Seq[Attribute], + allowFailures: Boolean = false): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + if (allowFailures) { + a + } else { + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + } } else { BoundReference(ordinal, a.dataType, a.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ba62dabe3dd6a..70507e7ee2be8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] { */ def foldable: Boolean = false def nullable: Boolean - def references: Set[Attribute] + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ def eval(input: Row = null): EvaluatedType @@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def foldable = left.foldable && right.foldable - override def references = left.references ++ right.references - override def toString = s"($left $symbol $right)" } @@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - override def references = child.references + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 8fc5896974438..ef1d12531f109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -27,7 +27,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) - protected val exprArray = expressions.toArray + // null check is required for when Kryo invokes the no-arg constructor. + protected val exprArray = if (expressions != null) expressions.toArray else null def apply(input: Row): Row = { val outputArray = new Array[Any](exprArray.length) @@ -109,7 +110,346 @@ class JoinedRow extends Row { def apply(i: Int) = if (i < row1.size) row1(i) else row2(i - row1.size) - def isNullAt(i: Int) = apply(i) == null + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} + +/** + * JIT HACK: Replace with macros + * The `JoinedRow` class is used in many performance critical situation. Unfortunately, since there + * are multiple different types of `Rows` that could be stored as `row1` and `row2` most of the + * calls in the critical path are polymorphic. By creating special versions of this class that are + * used in only a single location of the code, we increase the chance that only a single type of + * Row will be referenced, increasing the opportunity for the JIT to play tricks. This sounds + * crazy but in benchmarks it had noticeable effects. + */ +class JoinedRow2 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} + +/** + * JIT HACK: Replace with macros + */ +class JoinedRow3 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} + +/** + * JIT HACK: Replace with macros + */ +class JoinedRow4 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} + +/** + * JIT HACK: Replace with macros + */ +class JoinedRow5 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) def getInt(i: Int): Int = if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index 38f836f0a1a0e..851db95b9177e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType case object Rand extends LeafExpression { override def dataType = DoubleType override def nullable = false - override def references = Set.empty private[this] lazy val rand = new Random diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index c9a63e201ef60..d68a4fabeac77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -127,7 +127,7 @@ object EmptyRow extends Row { * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. */ -class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { +class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ def this() = this(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 95633dd0c9870..1b687a443ef8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,16 +17,19 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.util.ClosureCleaner case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { type EvaluatedType = Any - def references = children.flatMap(_.references).toSet def nullable = true + override def toString = s"scalaUDF(${children.mkString(",")})" + /** This method has been generated by this script (1 to 22).map { x => @@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:off override def eval(input: Row): Any = { - children.size match { + val result = children.size match { case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => @@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi children(21).eval(input)) } // scalastyle:on + + ScalaReflection.convertToCatalyst(result) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d2b7685e73065..d00b2ac09745c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -31,7 +31,6 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def references = child.references override def dataType = child.dataType override def nullable = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala new file mode 100644 index 0000000000000..088f11ee4aa53 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.types._ + +/** + * A parent class for mutable container objects that are reused when the values are changed, + * resulting in less garbage. These values are held by a [[SpecificMutableRow]]. + * + * The following code was roughly used to generate these objects: + * {{{ + * val types = "Int,Float,Boolean,Double,Short,Long,Byte,Any".split(",") + * types.map {tpe => + * s""" + * final class Mutable$tpe extends MutableValue { + * var value: $tpe = 0 + * def boxed = if (isNull) null else value + * def update(v: Any) = value = { + * isNull = false + * v.asInstanceOf[$tpe] + * } + * def copy() = { + * val newCopy = new Mutable$tpe + * newCopy.isNull = isNull + * newCopy.value = value + * newCopy.asInstanceOf[this.type] + * } + * }""" + * }.foreach(println) + * + * types.map { tpe => + * s""" + * override def set$tpe(ordinal: Int, value: $tpe): Unit = { + * val currentValue = values(ordinal).asInstanceOf[Mutable$tpe] + * currentValue.isNull = false + * currentValue.value = value + * } + * + * override def get$tpe(i: Int): $tpe = { + * values(i).asInstanceOf[Mutable$tpe].value + * }""" + * }.foreach(println) + * }}} + */ +abstract class MutableValue extends Serializable { + var isNull: Boolean = true + def boxed: Any + def update(v: Any) + def copy(): this.type +} + +final class MutableInt extends MutableValue { + var value: Int = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Int] + } + def copy() = { + val newCopy = new MutableInt + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableFloat extends MutableValue { + var value: Float = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Float] + } + def copy() = { + val newCopy = new MutableFloat + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableBoolean extends MutableValue { + var value: Boolean = false + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Boolean] + } + def copy() = { + val newCopy = new MutableBoolean + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableDouble extends MutableValue { + var value: Double = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Double] + } + def copy() = { + val newCopy = new MutableDouble + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableShort extends MutableValue { + var value: Short = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Short] + } + def copy() = { + val newCopy = new MutableShort + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableLong extends MutableValue { + var value: Long = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Long] + } + def copy() = { + val newCopy = new MutableLong + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableByte extends MutableValue { + var value: Byte = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Byte] + } + def copy() = { + val newCopy = new MutableByte + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableAny extends MutableValue { + var value: Any = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Any] + } + def copy() = { + val newCopy = new MutableAny + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +/** + * A row type that holds an array specialized container objects, of type [[MutableValue]], chosen + * based on the dataTypes of each column. The intent is to decrease garbage when modifying the + * values of primitive columns. + */ +final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { + + def this(dataTypes: Seq[DataType]) = + this( + dataTypes.map { + case IntegerType => new MutableInt + case ByteType => new MutableByte + case FloatType => new MutableFloat + case ShortType => new MutableShort + case DoubleType => new MutableDouble + case BooleanType => new MutableBoolean + case LongType => new MutableLong + case _ => new MutableAny + }.toArray) + + def this() = this(Seq.empty) + + override def length: Int = values.length + + override def setNullAt(i: Int): Unit = { + values(i).isNull = true + } + + override def apply(i: Int): Any = values(i).boxed + + override def isNullAt(i: Int): Boolean = values(i).isNull + + override def copy(): Row = { + val newValues = new Array[MutableValue](values.length) + var i = 0 + while (i < values.length) { + newValues(i) = values(i).copy() + i += 1 + } + new SpecificMutableRow(newValues) + } + + override def update(ordinal: Int, value: Any): Unit = { + if (value == null) setNullAt(ordinal) else values(ordinal).update(value) + } + + override def iterator: Iterator[Any] = values.map(_.boxed).iterator + + def setString(ordinal: Int, value: String) = update(ordinal, value) + + def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + + override def setInt(ordinal: Int, value: Int): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableInt] + currentValue.isNull = false + currentValue.value = value + } + + override def getInt(i: Int): Int = { + values(i).asInstanceOf[MutableInt].value + } + + override def setFloat(ordinal: Int, value: Float): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableFloat] + currentValue.isNull = false + currentValue.value = value + } + + override def getFloat(i: Int): Float = { + values(i).asInstanceOf[MutableFloat].value + } + + override def setBoolean(ordinal: Int, value: Boolean): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableBoolean] + currentValue.isNull = false + currentValue.value = value + } + + override def getBoolean(i: Int): Boolean = { + values(i).asInstanceOf[MutableBoolean].value + } + + override def setDouble(ordinal: Int, value: Double): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableDouble] + currentValue.isNull = false + currentValue.value = value + } + + override def getDouble(i: Int): Double = { + values(i).asInstanceOf[MutableDouble].value + } + + override def setShort(ordinal: Int, value: Short): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableShort] + currentValue.isNull = false + currentValue.value = value + } + + override def getShort(i: Int): Short = { + values(i).asInstanceOf[MutableShort].value + } + + override def setLong(ordinal: Int, value: Long): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableLong] + currentValue.isNull = false + currentValue.value = value + } + + override def getLong(i: Int): Long = { + values(i).asInstanceOf[MutableLong].value + } + + override def setByte(ordinal: Int, value: Byte): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableByte] + currentValue.isNull = false + currentValue.value = value + } + + override def getByte(i: Int): Byte = { + values(i).asInstanceOf[MutableByte].value + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index eb8898900d6a5..1eb55715794a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow def nullable = false - def references = children.toSet + def dataType = DynamicType override def eval(input: Row): DynamicRow = input match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala old mode 100644 new mode 100755 index 01947273b6ccc..1b4d892625dbb --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -22,6 +22,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.util.collection.OpenHashSet abstract class AggregateExpression extends Expression { self: Product => @@ -77,7 +78,7 @@ abstract class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def references = base.references + override def nullable = base.nullable override def dataType = base.dataType @@ -88,7 +89,7 @@ abstract class AggregateFunction } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = true override def dataType = child.dataType override def toString = s"MIN($child)" @@ -104,21 +105,22 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - var currentMin: Any = _ + val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) + val cmp = GreaterThan(currentMin, expr) override def update(input: Row): Unit = { - if (currentMin == null) { - currentMin = expr.eval(input) - } else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) { - currentMin = expr.eval(input) + if (currentMin.value == null) { + currentMin.value = expr.eval(input) + } else if(cmp.eval(input) == true) { + currentMin.value = expr.eval(input) } } - override def eval(input: Row): Any = currentMin + override def eval(input: Row): Any = currentMin.value } case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = true override def dataType = child.dataType override def toString = s"MAX($child)" @@ -134,21 +136,22 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - var currentMax: Any = _ + val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) + val cmp = LessThan(currentMax, expr) override def update(input: Row): Unit = { - if (currentMax == null) { - currentMax = expr.eval(input) - } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) { - currentMax = expr.eval(input) + if (currentMax.value == null) { + currentMax.value = expr.eval(input) + } else if(cmp.eval(input) == true) { + currentMax.value = expr.eval(input) } } - override def eval(input: Row): Any = currentMax + override def eval(input: Row): Any = currentMax.value } case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"COUNT($child)" @@ -161,18 +164,91 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance() = new CountFunction(child, this) } -case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression { +case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { + def this() = this(null) + override def children = expressions - override def references = expressions.flatMap(_.references).toSet + override def nullable = false override def dataType = LongType override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" override def newInstance() = new CountDistinctFunction(expressions, this) + + override def asPartial = { + val partialSet = Alias(CollectHashSet(expressions), "partialSets")() + SplitEvaluation( + CombineSetsAndCount(partialSet.toAttribute), + partialSet :: Nil) + } +} + +case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { + def this() = this(null) + + override def children = expressions + override def nullable = false + override def dataType = ArrayType(expressions.head.dataType) + override def toString = s"AddToHashSet(${expressions.mkString(",")})" + override def newInstance() = new CollectHashSetFunction(expressions, this) +} + +case class CollectHashSetFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) + + override def update(input: Row): Unit = { + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: Row): Any = { + seen + } +} + +case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { + def this() = this(null) + + override def children = inputSet :: Nil + override def nullable = false + override def dataType = LongType + override def toString = s"CombineAndCount($inputSet)" + override def newInstance() = new CombineSetsAndCountFunction(inputSet, this) +} + +case class CombineSetsAndCountFunction( + @transient inputSet: Expression, + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + override def update(input: Row): Unit = { + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) + } + } + + override def eval(input: Row): Any = seen.size.toLong } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -181,7 +257,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -190,7 +266,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -208,7 +284,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = DoubleType override def toString = s"AVG($child)" @@ -228,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"SUM($child)" @@ -246,7 +322,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"SUM(DISTINCT $child)" @@ -255,7 +331,6 @@ case class SumDistinct(child: Expression) } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references override def nullable = true override def dataType = child.dataType override def toString = s"FIRST($child)" @@ -269,6 +344,21 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance() = new FirstFunction(child, this) } +case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references = child.references + override def nullable = true + override def dataType = child.dataType + override def toString = s"LAST($child)" + + override def asPartial: SplitEvaluation = { + val partialLast = Alias(Last(child), "PartialLast")() + SplitEvaluation( + Last(partialLast.toAttribute), + partialLast :: Nil) + } + override def newInstance() = new LastFunction(child, this) +} + case class AverageFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -277,7 +367,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private val zero = Cast(Literal(0), expr.dataType) private var count: Long = _ - private val sum = MutableLiteral(zero.eval(EmptyRow)) + private val sum = MutableLiteral(zero.eval(null), expr.dataType) private val sumAsDouble = Cast(sum, DoubleType) private def addFunction(value: Any) = Add(sum, Literal(value)) @@ -350,7 +440,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val zero = Cast(Literal(0), expr.dataType) - private val sum = MutableLiteral(zero.eval(null)) + private val sum = MutableLiteral(zero.eval(null), expr.dataType) private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) @@ -379,17 +469,22 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus) } -case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression) +case class CountDistinctFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - val seen = new scala.collection.mutable.HashSet[Any]() + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) override def update(input: Row): Unit = { - val evaluatedExpr = expr.map(_.eval(input)) - if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) { - seen += evaluatedExpr + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) } } @@ -409,3 +504,16 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: Row): Any = result } + +case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var result: Any = null + + override def update(input: Row): Unit = { + result = input + } + + override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row]) + else null +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c79c1847cedf5..fe825fdcdae37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.types._ +import scala.math.pow case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any @@ -33,6 +34,19 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { } } +case class Sqrt(child: Expression) extends UnaryExpression { + type EvaluatedType = Any + + def dataType = DoubleType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"SQRT($child)" + + override def eval(input: Row): Any = { + n1(child, input, ((na,a) => math.sqrt(na.toDouble(a)))) + } +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => @@ -85,3 +99,48 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } + +case class MaxOf(left: Expression, right: Expression) extends Expression { + type EvaluatedType = Any + + override def foldable = left.foldable && right.foldable + + override def nullable = left.nullable && right.nullable + + override def children = left :: right :: Nil + + override def dataType = left.dataType + + override def eval(input: Row): Any = { + val leftEval = left.eval(input) + val rightEval = right.eval(input) + if (leftEval == null) { + rightEval + } else if (rightEval == null) { + leftEval + } else { + val numeric = left.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + if (numeric.compare(leftEval, rightEval) < 0) { + rightEval + } else { + leftEval + } + } + } + + override def toString = s"MaxOf($left, $right)" +} + +/** + * A function that get the absolute value of the numeric value. + */ +case class Abs(child: Expression) extends UnaryExpression { + type EvaluatedType = Any + + def dataType = child.dataType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"Abs($child)" + + override def eval(input: Row): Any = n1(child, input, _.abs(_)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index de2d67ce82ff1..5a3f013c34579 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -26,6 +26,10 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ +// These classes are here to avoid issues with serialization and integration with quasiquotes. +class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] +class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] + /** * A base class for generators of byte code to perform expression evaluation. Includes a set of * helpers for referring to Catalyst types and building trees that perform evaluation of individual @@ -50,6 +54,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private val curId = new java.util.concurrent.atomic.AtomicInteger() private val javaSeparator = "$" + /** + * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. + */ + var debugLogging = false + /** * Generates a class for a given input expression. Called when there is not cached code * already available. @@ -71,7 +80,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most * fundamental difference is that a ConcurrentMap persists all elements that are added to it until * they are explicitly removed. A Cache on the other hand is generally configured to evict entries - * automatically, in order to constrain its memory footprint + * automatically, in order to constrain its memory footprint. Note that this cache does not use + * weak keys/values and thus does not respond to memory pressure. */ protected val cache = CacheBuilder.newBuilder() .maximumSize(1000) @@ -403,6 +413,78 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin $primitiveTerm = ${falseEval.primitiveTerm} } """.children + + case NewSet(elementType) => + q""" + val $nullTerm = false + val $primitiveTerm = new ${hashSetForType(elementType)}() + """.children + + case AddItemToSet(item, set) => + val itemEval = expressionEvaluator(item) + val setEval = expressionEvaluator(set) + + val ArrayType(elementType, _) = set.dataType + + itemEval.code ++ setEval.code ++ + q""" + if (!${itemEval.nullTerm}) { + ${setEval.primitiveTerm} + .asInstanceOf[${hashSetForType(elementType)}] + .add(${itemEval.primitiveTerm}) + } + + val $nullTerm = false + val $primitiveTerm = ${setEval.primitiveTerm} + """.children + + case CombineSets(left, right) => + val leftEval = expressionEvaluator(left) + val rightEval = expressionEvaluator(right) + + val ArrayType(elementType, _) = left.dataType + + leftEval.code ++ rightEval.code ++ + q""" + val $nullTerm = false + var $primitiveTerm: ${hashSetForType(elementType)} = null + + { + val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val iterator = rightSet.iterator + while (iterator.hasNext) { + leftSet.add(iterator.next()) + } + $primitiveTerm = leftSet + } + """.children + + case MaxOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + + if (${eval1.nullTerm}) { + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} + } else if (${eval2.nullTerm}) { + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $nullTerm = false + if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $primitiveTerm = ${eval2.primitiveTerm} + } + } + """.children + } // If there was no match in the partial function above, we fall back on calling the interpreted @@ -420,7 +502,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Only inject debugging code if debugging is turned on. val debugCode = - if (log.isDebugEnabled) { + if (debugLogging) { val localLogger = log val localLoggerTree = reify { localLogger } q""" @@ -454,6 +536,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + protected def hashSetForType(dt: DataType) = dt match { + case IntegerType => typeOf[IntegerHashSet] + case LongType => typeOf[LongHashSet] + case unsupportedType => + sys.error(s"Code generation not support for hashset of type $unsupportedType") + } + protected def primitiveForType(dt: DataType) = dt match { case IntegerType => "Int" case LongType => "Long" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 77fa02c13de30..7871a62620478 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -69,8 +69,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ..${evaluatedExpression.code} if(${evaluatedExpression.nullTerm}) setNullAt($iLit) - else + else { + nullBits($iLit) = false $elementName = ${evaluatedExpression.primitiveTerm} + } } """.children : Seq[Tree] } @@ -106,9 +108,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if(value == null) { setNullAt(i) } else { + nullBits(i) = false $elementName = value.asInstanceOf[${termForType(e.dataType)}] - return } + return }""" } q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" @@ -137,7 +140,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? - q"if(i == $i) { $elementName = value; return }" :: Nil + q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index c1154eb81c319..dafd745ec96c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { /** `Null` is returned for invalid ordinals. */ override def nullable = true override def foldable = child.foldable && ordinal.foldable - override def references = children.flatMap(_.references).toSet + def dataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e99c5b452d183..9c865254e0be9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -47,8 +47,6 @@ abstract class Generator extends Expression { override def nullable = false - override def references = children.flatMap(_.references).toSet - /** * Should be overridden by specific generators. Called only once for each instance to ensure * that rule application does not change the output schema of a generator. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e15e16d633365..78a0c55e4bbe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -52,7 +52,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { override def foldable = true def nullable = value == null - def references = Set.empty + override def toString = if (value != null) value.toString else "null" @@ -61,13 +61,10 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { } // TODO: Specialize -case class MutableLiteral(var value: Any, nullable: Boolean = true) extends LeafExpression { +case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) + extends LeafExpression { type EvaluatedType = Any - val dataType = Literal(value).dataType - - def references = Set.empty - def update(expression: Expression, input: Row) = { value = expression.eval(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 02d04762629f5..7c4b9d4847e26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression { def toAttribute = this def newInstance: Attribute - override def references = Set(this) + } /** @@ -85,7 +85,7 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable - override def references = child.references + override def toAttribute = { if (resolved) { @@ -116,6 +116,8 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { + override def references = AttributeSet(this :: Nil) + override def equals(other: Any) = other match { case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index e88c5d4fa178a..086d0a3e073e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -26,7 +26,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ def nullable = !children.exists(!_.nullable) - def references = children.flatMap(_.references).toSet // Coalesce is foldable if all children are foldable. override def foldable = !children.exists(!_.foldable) @@ -53,7 +52,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - def references = child.references override def foldable = child.foldable def nullable = false @@ -65,7 +63,6 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - def references = child.references override def foldable = child.foldable def nullable = false override def toString = s"IS NOT NULL $child" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5976b0ddf3e03..329af332d0fa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { def children = value +: list - def references = children.flatMap(_.references).toSet + def nullable = true // TODO: Figure out correct nullability semantics of IN. override def toString = s"$value IN ${list.mkString("(", ",", ")")}" @@ -197,7 +197,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi def children = predicate :: trueValue :: falseValue :: Nil override def nullable = trueValue.nullable || falseValue.nullable - def references = children.flatMap(_.references).toSet + override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType def dataType = { if (!resolved) { @@ -239,7 +239,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any def children = branches - def references = children.flatMap(_.references).toSet + def dataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") @@ -265,12 +265,13 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { false } else { val allCondBooleans = predicates.forall(_.dataType == BooleanType) - val dataTypesEqual = values.map(_.dataType).distinct.size <= 1 + // both then and else val should be considered. + val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1 allCondBooleans && dataTypesEqual } } - /** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */ + /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { val len = branchesArr.length var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala new file mode 100644 index 0000000000000..3d4c4a8853c12 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.util.collection.OpenHashSet + +/** + * Creates a new set of the specified type + */ +case class NewSet(elementType: DataType) extends LeafExpression { + type EvaluatedType = Any + + def nullable = false + + // We are currently only using these Expressions internally for aggregation. However, if we ever + // expose these to users we'll want to create a proper type instead of hijacking ArrayType. + def dataType = ArrayType(elementType) + + def eval(input: Row): Any = { + new OpenHashSet[Any]() + } + + override def toString = s"new Set($dataType)" +} + +/** + * Adds an item to a set. + * For performance, this expression mutates its input during evaluation. + */ +case class AddItemToSet(item: Expression, set: Expression) extends Expression { + type EvaluatedType = Any + + def children = item :: set :: Nil + + def nullable = set.nullable + + def dataType = set.dataType + def eval(input: Row): Any = { + val itemEval = item.eval(input) + val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] + + if (itemEval != null) { + if (setEval != null) { + setEval.add(itemEval) + setEval + } else { + null + } + } else { + setEval + } + } + + override def toString = s"$set += $item" +} + +/** + * Combines the elements of two sets. + * For performance, this expression mutates its left input set during evaluation. + */ +case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { + type EvaluatedType = Any + + def nullable = left.nullable || right.nullable + + def dataType = left.dataType + + def symbol = "++=" + + def eval(input: Row): Any = { + val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] + if(leftEval != null) { + val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] + if (rightEval != null) { + val iterator = rightEval.iterator + while(iterator.hasNext) { + val rightValue = iterator.next() + leftEval.add(rightValue) + } + leftEval + } else { + null + } + } else { + null + } + } +} + +/** + * Returns the number of elements in the input set. + */ +case class CountSet(child: Expression) extends UnaryExpression { + type EvaluatedType = Any + + def nullable = child.nullable + + def dataType = LongType + + def eval(input: Row): Any = { + val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] + if (childEval != null) { + childEval.size.toLong + } + } + + override def toString = s"$child.count()" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 97fc3a3b14b88..c2a3a5ca3ca8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -226,8 +226,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends if (str.dataType == BinaryType) str.dataType else StringType } - def references = children.flatMap(_.references).toSet - override def children = str :: pos :: len :: Nil @inline diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5f86d6047cb9c..a4133feae8166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,12 +40,60 @@ object Optimizer extends RuleExecutor[LogicalPlan] { SimplifyCasts, SimplifyCaseConversionExpressions) :: Batch("Filter Pushdown", FixedPoint(100), + UnionPushdown, CombineFilters, PushPredicateThroughProject, PushPredicateThroughJoin, ColumnPruning) :: Nil } +/** + * Pushes operations to either side of a Union. + */ +object UnionPushdown extends Rule[LogicalPlan] { + + /** + * Maps Attributes from the left side to the corresponding Attribute on the right side. + */ + def buildRewrites(union: Union): AttributeMap[Attribute] = { + assert(union.left.output.size == union.right.output.size) + + AttributeMap(union.left.output.zip(union.right.output)) + } + + /** + * Rewrites an expression so that it can be pushed to the right side of a Union operator. + * This method relies on the fact that the output attributes of a union are always equal + * to the left child's output. + */ + def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { + val result = e transform { + case a: Attribute => rewrites(a) + } + + // We must promise the compiler that we did not discard the names in the case of project + // expressions. This is safe since the only transformation is from Attribute => Attribute. + result.asInstanceOf[A] + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Push down filter into union + case Filter(condition, u @ Union(left, right)) => + val rewrites = buildRewrites(u) + Union( + Filter(condition, left), + Filter(pushToRight(condition, rewrites), right)) + + // Push down projection into union + case Project(projectList, u @ Union(left, right)) => + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } +} + + /** * Attempts to eliminate the reading of unneeded columns from the query plan using the following * transformations: @@ -65,8 +113,10 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate unneeded attributes from either side of a Join. case Project(projectList, Join(left, right, joinType, condition)) => // Collect the list of all references required either above or to evaluate the condition. - val allReferences: Set[Attribute] = - projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty) + val allReferences: AttributeSet = + AttributeSet( + projectList.flatMap(_.references.iterator)) ++ + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) @@ -76,8 +126,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate unneeded attributes from right side of a LeftSemiJoin. case Join(left, right, LeftSemi, condition) => // Collect the list of all references required to evaluate the condition. - val allReferences: Set[Attribute] = - condition.map(_.references).getOrElse(Set.empty) + val allReferences: AttributeSet = + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) Join(left, prunedChild(right, allReferences), LeftSemi, condition) @@ -104,7 +154,7 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 90923fe31a063..f0fd9a8b9a46e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -134,8 +135,8 @@ object PartialAggregation { // Only do partial aggregation if supported by all aggregate expressions. if (allAggregates.size == partialAggregates.size) { // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap + val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = + partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be @@ -148,8 +149,8 @@ object PartialAggregation { // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation + case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => + partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression if namedGroupingExpressions.contains(e) => namedGroupingExpressions(e).toAttribute }).asInstanceOf[Seq[NamedExpression]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0988b0c6d990c..af9e4d86e995a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} @@ -29,7 +29,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Returns the set of attributes that are output by this node. */ - def outputSet: Set[Attribute] = output.toSet + def outputSet: AttributeSet = AttributeSet(output) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. @@ -50,11 +50,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy @inline def transformExpressionDown(e: Expression) = { val newE = e.transformDown(rule) - if (newE.id != e.id && newE != e) { + if (newE.fastEquals(e)) { + e + } else { changed = true newE - } else { - e } } @@ -82,11 +82,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy @inline def transformExpressionUp(e: Expression) = { val newE = e.transformUp(rule) - if (newE.id != e.id && newE != e) { + if (newE.fastEquals(e)) { + e + } else { changed = true newE - } else { - e } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 278569f0cb14a..ede431ad4ab27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -41,25 +41,24 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { case class Statistics( sizeInBytes: BigInt ) - lazy val statistics: Statistics = Statistics( - sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product - ) + lazy val statistics: Statistics = { + if (children.size == 0) { + throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") + } - /** - * Returns the set of attributes that are referenced by this node - * during evaluation. - */ - def references: Set[Attribute] + Statistics( + sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) + } /** * Returns the set of attributes that this node takes as * input from its children. */ - lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet + lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output)) /** * Returns true if this expression and all its children have been resolved to a specific schema - * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan + * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan * can override this (e.g. * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] * should return `false`). @@ -105,11 +104,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - a.dataType match { - case StructType(fields) => - Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) - case _ => None // Don't know how to resolve these field references - } + Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) case Seq() => None // No matches. case ambiguousReferences => throw new TreeNodeException( @@ -123,12 +118,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { */ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => - - override lazy val statistics: Statistics = - throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") - - // Leaf nodes by definition cannot reference any input attributes. - override def references = Set.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index d3f9d0fb93237..4460c86ed9026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -30,6 +30,4 @@ case class ScriptTransformation( input: Seq[Expression], script: String, output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - def references = input.flatMap(_.references).toSet -} + child: LogicalPlan) extends UnaryNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 3cb407217c4c3..5d10754c7b028 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { def output = projectList.map(_.toAttribute) - def references = projectList.flatMap(_.references).toSet } /** @@ -59,14 +58,10 @@ case class Generate( override def output = if (join) child.output ++ generatorOutput else generatorOutput - - override def references = - if (join) child.outputSet else generator.references } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = condition.references } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -76,8 +71,6 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } - - override def references = Set.empty } case class Join( @@ -86,8 +79,6 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def references = condition.map(_.references).getOrElse(Set.empty) - override def output = { joinType match { case LeftSemi => @@ -106,8 +97,6 @@ case class Join( case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { def output = left.output - - def references = Set.empty } case class InsertIntoTable( @@ -118,7 +107,6 @@ case class InsertIntoTable( extends LogicalPlan { // The table being inserted into is a child for the purposes of transformations. override def children = table :: child :: Nil - override def references = Set.empty override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { @@ -126,24 +114,22 @@ case class InsertIntoTable( } } -case class InsertIntoCreatedTable( +case class CreateTableAsSelect( databaseName: Option[String], tableName: String, child: LogicalPlan) extends UnaryNode { - override def references = Set.empty override def output = child.output + override lazy val resolved = (databaseName != None && childrenResolved) } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def references = Set.empty override def output = child.output } case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = order.flatMap(_.references).toSet } case class Aggregate( @@ -152,19 +138,20 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { + /** The set of all AttributeReferences required for this aggregation. */ + def references = + AttributeSet( + groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references)) + override def output = aggregateExpressions.map(_.toAttribute) - override def references = - (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output = child.output.map(_.withQualifiers(alias :: Nil)) - override def references = Set.empty } /** @@ -191,20 +178,16 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { a.qualifiers) case other => other } - - override def references = Set.empty } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = Set.empty } case class Distinct(child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = child.outputSet } case object NoRelation extends LeafNode { @@ -213,5 +196,4 @@ case object NoRelation extends LeafNode { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output = left.output - override def references = Set.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 481a5a4f212b2..a01809c1fc5e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -50,7 +50,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman * Returned by a parser when the users only wants to see what query plan would be executed, without * actually performing the execution. */ -case class ExplainCommand(plan: LogicalPlan) extends Command { +case class ExplainCommand(plan: LogicalPlan, extended: Boolean = false) extends Command { override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 7146fbd540f29..72b0c5c8e7a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -31,13 +31,9 @@ abstract class RedistributeData extends UnaryNode { case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData { - - def references = sortExpressions.flatMap(_.references).toSet } case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) extends RedistributeData { - - def references = partitionExpressions.flatMap(_.references).toSet } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4bb022cf238af..ccb0df113c063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -71,6 +71,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + // TODO: This is not really valid... def clustering = ordering.map(_.child).toSet } @@ -139,7 +140,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) with Partitioning { override def children = expressions - override def references = expressions.flatMap(_.references).toSet override def nullable = false override def dataType = IntegerType @@ -179,7 +179,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) with Partitioning { override def children = ordering - override def references = ordering.flatMap(_.references).toSet override def nullable = false override def dataType = IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cd04bdf02cf84..2013ae4f7bd13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -19,11 +19,6 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ -object TreeNode { - private val currentId = new java.util.concurrent.atomic.AtomicLong - protected def nextId() = currentId.getAndIncrement() -} - /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -33,29 +28,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a Seq of the children of this node */ def children: Seq[BaseType] - /** - * A globally unique id for this specific instance. Not preserved across copies. - * Unlike `equals`, `id` can be used to differentiate distinct but structurally - * identical branches of a tree. - */ - val id = TreeNode.nextId() - - /** - * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance. Unlike - * `equals` this function will return false for different instances of structurally identical - * trees. - */ - def sameInstance(other: TreeNode[_]): Boolean = { - this.id == other.id - } - /** * Faster version of equality which short-circuits when two treeNodes are the same instance. * We don't just override Object.Equals, as doing so prevents the scala compiler from from * generating case class `equals` methods */ def fastEquals(other: TreeNode[_]): Boolean = { - sameInstance(other) || this == other + this.eq(other) || this == other } /** @@ -280,7 +259,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { try { - val defaultCtor = getClass.getConstructors.head + // Skip no-arg constructors that are just there for kryo. + val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head if (otherCopyArgs.isEmpty) { defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] } else { @@ -392,3 +372,4 @@ trait UnaryNode[BaseType <: TreeNode[BaseType]] { def child: BaseType def children = child :: Nil } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index d725a92c06f7b..79a8e06d4b4d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -37,4 +37,15 @@ package object trees extends Logging { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. protected override def logName = "catalyst.trees" + /** + * A [[TreeNode]] companion for reference equality for Hash based Collection. + */ + class TreeNodeRef(val obj: TreeNode[_]) { + override def equals(o: Any) = o match { + case that: TreeNodeRef => that.obj.eq(obj) + case _ => false + } + + override def hashCode = if (obj == null) 0 else obj.hashCode + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index b52ee6d3378a3..70c6d06cf2534 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -270,8 +270,8 @@ case object FloatType extends FractionalType { } object ArrayType { - /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */ - def apply(elementType: DataType): ArrayType = ArrayType(elementType, false) + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index e75373d5a74a7..428607d8c8253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -57,7 +57,9 @@ case class OptionalData( case class ComplexData( arrayField: Seq[Int], - mapField: Map[Int, String], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: Map[Int, Long], + mapFieldValueContainsNull: Map[Int, java.lang.Long], structField: PrimitiveData) case class GenericData[A]( @@ -116,8 +118,22 @@ class ScalaReflectionSuite extends FunSuite { val schema = schemaFor[ComplexData] assert(schema === Schema( StructType(Seq( - StructField("arrayField", ArrayType(IntegerType), nullable = true), - StructField("mapField", MapType(IntegerType, StringType), nullable = true), + StructField( + "arrayField", + ArrayType(IntegerType, containsNull = false), + nullable = true), + StructField( + "arrayFieldContainsNull", + ArrayType(IntegerType, containsNull = true), + nullable = true), + StructField( + "mapField", + MapType(IntegerType, LongType, valueContainsNull = false), + nullable = true), + StructField( + "mapFieldValueContainsNull", + MapType(IntegerType, LongType, valueContainsNull = true), + nullable = true), StructField( "structField", StructType(Seq( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0a4fde3de7752..5809a108ff62e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -93,6 +93,17 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val e = intercept[TreeNodeException[_]] { caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation)) } - assert(e.getMessage().toLowerCase.contains("unresolved")) + assert(e.getMessage().toLowerCase.contains("unresolved attribute")) + } + + test("throw errors for unresolved plans during analysis") { + case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output = Nil + } + val e = intercept[TreeNodeException[_]] { + caseSensitiveAnalyze(UnresolvedTestPlan()) + } + assert(e.getMessage().toLowerCase.contains("unresolved plan")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index b9e0f8e9dcc5f..baeb9b0cf5964 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -19,24 +19,26 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.types._ class HiveTypeCoercionSuite extends FunSuite { - val rules = new HiveTypeCoercion { } - import rules._ - - test("tightest common bound for numeric and boolean types") { + test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = WidenTypes.findTightestCommonType(t1, t2) + var found = HiveTypeCoercion.findTightestCommonType(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = WidenTypes.findTightestCommonType(t2, t1) + found = HiveTypeCoercion.findTightestCommonType(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } + // Null + widenTest(NullType, NullType, Some(NullType)) + // Boolean widenTest(NullType, BooleanType, Some(BooleanType)) widenTest(BooleanType, BooleanType, Some(BooleanType)) @@ -60,12 +62,41 @@ class HiveTypeCoercionSuite extends FunSuite { widenTest(DoubleType, DoubleType, Some(DoubleType)) // Integral mixed with floating point. - widenTest(NullType, FloatType, Some(FloatType)) - widenTest(NullType, DoubleType, Some(DoubleType)) widenTest(IntegerType, FloatType, Some(FloatType)) widenTest(IntegerType, DoubleType, Some(DoubleType)) widenTest(IntegerType, DoubleType, Some(DoubleType)) widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) + + // StringType + widenTest(NullType, StringType, Some(StringType)) + widenTest(StringType, StringType, Some(StringType)) + widenTest(IntegerType, StringType, None) + widenTest(LongType, StringType, None) + + // TimestampType + widenTest(NullType, TimestampType, Some(TimestampType)) + widenTest(TimestampType, TimestampType, Some(TimestampType)) + widenTest(IntegerType, TimestampType, None) + widenTest(StringType, TimestampType, None) + + // ComplexType + widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false))) + widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) + widenTest(StringType, MapType(IntegerType, StringType, true), None) + widenTest(ArrayType(IntegerType), StructType(Seq()), None) + } + + test("boolean casts") { + val booleanCasts = new HiveTypeCoercion { }.BooleanCasts + def ruleTest(initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) == + Project(Seq(Alias(transformed, "a")()), testRelation)) + } + // Remove superflous boolean -> boolean casts. + ruleTest(Cast(Literal(true), BooleanType), Literal(true)) + // Stringify boolean when casting to string. + ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false"))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 999c9fff38d60..b961346dfc995 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -136,6 +136,16 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) } + test("MaxOf") { + checkEvaluation(MaxOf(1, 2), 2) + checkEvaluation(MaxOf(2, 1), 2) + checkEvaluation(MaxOf(1L, 2L), 2L) + checkEvaluation(MaxOf(2L, 1L), 2L) + + checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal(null, StringType).like("a"), null) checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) @@ -567,4 +577,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(s.substring(0, 2), "ex", row) checkEvaluation(s.substring(0), "example", row) } + + test("SQRT") { + val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) + val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) + val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val d = 'a.double.at(0) + + for ((row, expected) <- rowSequence zip expectedResults) { + checkEvaluation(Sqrt(d), expected, row) + } + + checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala new file mode 100644 index 0000000000000..dfef87bd9133d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class UnionPushdownSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateAnalysisOperators) :: + Batch("Union Pushdown", Once, + UnionPushdown) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testUnion = Union(testRelation, testRelation2) + + test("union: filter to each side") { + val query = testUnion.where('a === 1) + + val optimized = Optimize(query.analyze) + + val correctAnswer = + Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("union: project to each side") { + val query = testUnion.select('b) + + val optimized = Optimize(query.analyze) + + val correctAnswer = + Union(testRelation.select('b), testRelation2.select('e)).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 6344874538d67..036fd3fa1d6a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { def children = optKey.toSeq - def references = Set.empty[Attribute] def nullable = true def dataType = NullType override lazy val resolved = true @@ -52,7 +51,10 @@ class TreeNodeSuite extends FunSuite { val after = before transform { case Literal(5, _) => Literal(1)} assert(before === after) - assert(before.map(_.id) === after.map(_.id)) + // Ensure that the objects after are the same objects before the transformation. + before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach { + case (b, a) => assert(b eq a) + } } test("collect") { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c8016e41256d5..bd110218d34f7 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 3eccddef88134..37b4c8ffcba0b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -86,14 +86,14 @@ public abstract class DataType { /** * Creates an ArrayType by specifying the data type of elements ({@code elementType}). - * The field of {@code containsNull} is set to {@code false}. + * The field of {@code containsNull} is set to {@code true}. */ public static ArrayType createArrayType(DataType elementType) { if (elementType == null) { throw new IllegalArgumentException("elementType should not be null."); } - return new ArrayType(elementType, false); + return new ArrayType(elementType, true); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5cc41a83cc792..f6f4cf3b80d41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -26,6 +26,7 @@ import java.util.Properties private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" + val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" @@ -33,6 +34,7 @@ private[spark] object SQLConf { val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" + val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -51,7 +53,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -trait SQLConf { +private[sql] trait SQLConf { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -78,6 +80,9 @@ trait SQLConf { /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + /** The compression codec for writing to a Parquetfile */ + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "snappy") + /** The number of rows that will be */ private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt @@ -88,7 +93,7 @@ trait SQLConf { * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster * than interpreted evaluation, but there are significant start-up costs due to compilation. - * As a result codegen is only benificial when queries run for a long time, or when the same + * As a result codegen is only beneficial when queries run for a long time, or when the same * expressions are used multiple times. * * Defaults to false as this feature is currently experimental. @@ -107,8 +112,9 @@ trait SQLConf { /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, - * it is set to a larger value than `autoConvertJoinSize`, hence any logical operator without a - * properly implemented estimation of this statistic will not be incorrectly broadcasted in joins. + * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator + * without a properly implemented estimation of this statistic will not be incorrectly broadcasted + * in joins. */ private[spark] def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong @@ -119,6 +125,12 @@ trait SQLConf { private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** + * When set to true, partition pruning for in-memory columnar tables is enabled. + */ + private[spark] def inMemoryPartitionPruning: Boolean = + getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index af9f7c62a1d25..c551c7c9877e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -89,8 +89,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = + implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { + SparkPlan.currentContext.set(self) new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) + } /** * :: DeveloperApi :: @@ -244,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(None, tableName, rdd.logicalPlan) + catalog.registerTable(None, tableName, rdd.queryExecution.analyzed) } /** @@ -270,7 +272,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val currentTable = table(tableName).queryExecution.analyzed val asInMemoryRelation = currentTable match { case _: InMemoryRelation => - currentTable.logicalPlan + currentTable case _ => InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) @@ -344,8 +346,8 @@ class SQLContext(@transient val sparkContext: SparkContext) prunePushedDownFilters: Seq[Expression] => Seq[Expression], scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - val projectSet = projectList.flatMap(_.references).toSet - val filterSet = filterPredicates.flatMap(_.references).toSet + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And) // Right now we still use a projection even if the only evaluation is applying an alias @@ -354,7 +356,8 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO: Decouple final output schema from expression evaluation so this copy can be // avoided safely. - if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) { + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, // just do a scan followed by a filter, with no extra project. @@ -408,10 +411,18 @@ class SQLContext(@transient val sparkContext: SparkContext) protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } - def simpleString: String = stringOrError(executedPlan) + def simpleString: String = + s"""== Physical Plan == + |${stringOrError(executedPlan)} + """ override def toString: String = - s"""== Logical Plan == + // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)}) + // however, the `toRdd` will cause the real execution, which is not what we want. + // We need to think about how to avoid the side effect. + s"""== Parsed Logical Plan == + |${stringOrError(logical)} + |== Analyzed Logical Plan == |${stringOrError(analyzed)} |== Optimized Logical Plan == |${stringOrError(optimizedPlan)} @@ -419,7 +430,6 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(executedPlan)} |Code Generation: ${executedPlan.codegenEnabled} |== RDD == - |${stringOrError(toRdd.toDebugString)} """.stripMargin.trim } @@ -450,7 +460,6 @@ class SQLContext(@transient val sparkContext: SparkContext) rdd: RDD[Array[Any]], schema: StructType): SchemaRDD = { import scala.collection.JavaConversions._ - import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true @@ -472,8 +481,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case (null, _) => null case (c: java.util.List[_], ArrayType(elementType, _)) => - val converted = c.map { e => convert(e, elementType)} - JListWrapper(converted) + c.map { e => convert(e, elementType)}: Seq[Any] case (c, ArrayType(elementType, _)) if c.getClass.isArray => c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 33b2ed1b3a399..d2ceb4a2b0b25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -428,7 +428,8 @@ class SchemaRDD( */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { new SchemaRDD(sqlContext, - SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))(sqlContext)) + SparkLogicalPlan( + ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext)) } // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 2f3033a5f94f0..e52eeb3e1c47e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -54,7 +54,7 @@ private[sql] trait SchemaRDDLike { @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { // For various commands (like DDL) and queries with side effects, we force query optimization to // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile => queryExecution.toRdd SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) case _ => @@ -124,7 +124,7 @@ private[sql] trait SchemaRDDLike { */ @Experimental def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd + sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan)).toRdd /** Returns the schema as a string in the tree format. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 0b48e9e659faa..595b4aa36eae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.{List => JList, Map => JMap} import org.apache.spark.Accumulator +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} import org.apache.spark.sql.execution.PythonUDF @@ -29,7 +30,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} /** * Functions for registering scala lambda functions as UDFs in a SQLContext. */ -protected[sql] trait UDFRegistration { +private[sql] trait UDFRegistration { self: SQLContext => private[spark] def registerPython( @@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration { envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], stringDataType: String): Unit = { log.debug( @@ -61,6 +63,7 @@ protected[sql] trait UDFRegistration { envVars, pythonIncludes, pythonExec, + broadcastVars, accumulator, dataType, e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 6c67934bda5b8..e9d04ce7aae4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -25,7 +25,7 @@ import scala.math.BigDecimal import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** - * A result row from a SparkSQL query. + * A result row from a Spark SQL query. */ class Row(private[spark] val row: ScalaRow) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 7e7bb2859bbcd..b3ec5ded22422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -38,7 +38,7 @@ private[sql] trait ColumnBuilder { /** * Column statistics information */ - def columnStats: ColumnStats[_, _] + def columnStats: ColumnStats /** * Returns the final columnar byte buffer. @@ -47,7 +47,7 @@ private[sql] trait ColumnBuilder { } private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( - val columnStats: ColumnStats[T, JvmType], + val columnStats: ColumnStats, val columnType: ColumnType[T, JvmType]) extends ColumnBuilder { @@ -75,25 +75,24 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( } override def build() = { - buffer.limit(buffer.position()).rewind() - buffer + buffer.flip().asInstanceOf[ByteBuffer] } } private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: NativeType]( - override val columnStats: NativeColumnStats[T], + override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN) private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) @@ -129,7 +128,6 @@ private[sql] object ColumnBuilder { val newSize = capacity + size.max(capacity / 8 + 1) val pos = orig.position() - orig.clear() ByteBuffer .allocate(newSize) .order(ByteOrder.nativeOrder()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 6502110e903fe..fc343ccb995c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,381 +17,193 @@ package org.apache.spark.sql.columnar +import java.sql.Timestamp + import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.types._ +private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { + val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)() + val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)() + val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() + + val schema = Seq(lowerBound, upperBound, nullCount) +} + +private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { + val (forAttribute, schema) = { + val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) + (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) + } +} + /** * Used to collect statistical information when building in-memory columns. * * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable { - /** - * Closed lower bound of this column. - */ - def lowerBound: JvmType - - /** - * Closed upper bound of this column. - */ - def upperBound: JvmType - +private[sql] sealed trait ColumnStats extends Serializable { /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: Row, ordinal: Int) - - /** - * Returns `true` if `lower <= row(ordinal) <= upper`. - */ - def contains(row: Row, ordinal: Int): Boolean + def gatherStats(row: Row, ordinal: Int): Unit /** - * Returns `true` if `row(ordinal) < upper` holds. + * Column statistics represented as a single row, currently including closed lower bound, closed + * upper bound and null count. */ - def isAbove(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `lower < row(ordinal)` holds. - */ - def isBelow(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `row(ordinal) <= upper` holds. - */ - def isAtOrAbove(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `lower <= row(ordinal)` holds. - */ - def isAtOrBelow(row: Row, ordinal: Int): Boolean -} - -private[sql] sealed abstract class NativeColumnStats[T <: NativeType] - extends ColumnStats[T, T#JvmType] { - - type JvmType = T#JvmType - - protected var (_lower, _upper) = initialBounds - - def initialBounds: (JvmType, JvmType) - - protected def columnType: NativeColumnType[T] - - override def lowerBound: T#JvmType = _lower - - override def upperBound: T#JvmType = _upper - - override def isAtOrAbove(row: Row, ordinal: Int) = { - contains(row, ordinal) || isAbove(row, ordinal) - } - - override def isAtOrBelow(row: Row, ordinal: Int) = { - contains(row, ordinal) || isBelow(row, ordinal) - } + def collectedStatistics: Row } -private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] { - override def isAtOrBelow(row: Row, ordinal: Int) = true - - override def isAtOrAbove(row: Row, ordinal: Int) = true - - override def isBelow(row: Row, ordinal: Int) = true - - override def isAbove(row: Row, ordinal: Int) = true +private[sql] class NoopColumnStats extends ColumnStats { - override def contains(row: Row, ordinal: Int) = true + override def gatherStats(row: Row, ordinal: Int): Unit = {} - override def gatherStats(row: Row, ordinal: Int) {} - - override def upperBound = null.asInstanceOf[JvmType] - - override def lowerBound = null.asInstanceOf[JvmType] + override def collectedStatistics = Row() } -private[sql] abstract class BasicColumnStats[T <: NativeType]( - protected val columnType: NativeColumnType[T]) - extends NativeColumnStats[T] - -private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) { - override def initialBounds = (true, false) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class ByteColumnStats extends ColumnStats { + var upper = Byte.MinValue + var lower = Byte.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) { - override def initialBounds = (Byte.MaxValue, Byte.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound + if (!row.isNullAt(ordinal)) { + val value = row.getByte(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) { - override def initialBounds = (Short.MaxValue, Short.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class ShortColumnStats extends ColumnStats { + var upper = Short.MinValue + var lower = Short.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class LongColumnStats extends BasicColumnStats(LONG) { - override def initialBounds = (Long.MaxValue, Long.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound + if (!row.isNullAt(ordinal)) { + val value = row.getShort(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) { - override def initialBounds = (Double.MaxValue, Double.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class LongColumnStats extends ColumnStats { + var upper = Long.MinValue + var lower = Long.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) { - override def initialBounds = (Float.MaxValue, Float.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row.getLong(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } + def collectedStatistics = Row(lower, upper, nullCount) +} - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class DoubleColumnStats extends ColumnStats { + var upper = Double.MinValue + var lower = Double.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field + if (!row.isNullAt(ordinal)) { + val value = row.getDouble(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } -} -private[sql] object IntColumnStats { - val UNINITIALIZED = 0 - val INITIALIZED = 1 - val ASCENDING = 2 - val DESCENDING = 3 - val UNORDERED = 4 + def collectedStatistics = Row(lower, upper, nullCount) } -/** - * Statistical information for `Int` columns. More information is collected since `Int` is - * frequently used. Extra information include: - * - * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search - * is applicable when searching elements. - * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression - * scheme. - * - * (This two kinds of information are not used anywhere yet and might be removed later.) - */ -private[sql] class IntColumnStats extends BasicColumnStats(INT) { - import IntColumnStats._ - - private var orderedState = UNINITIALIZED - private var lastValue: Int = _ - private var _maxDelta: Int = _ - - def isAscending = orderedState != DESCENDING && orderedState != UNORDERED - def isDescending = orderedState != ASCENDING && orderedState != UNORDERED - def isOrdered = isAscending || isDescending - def maxDelta = _maxDelta - - override def initialBounds = (Int.MaxValue, Int.MinValue) +private[sql] class FloatColumnStats extends ColumnStats { + var upper = Float.MinValue + var lower = Float.MaxValue + var nullCount = 0 - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) + override def gatherStats(row: Row, ordinal: Int) { + if (!row.isNullAt(ordinal)) { + val value = row.getFloat(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } + def collectedStatistics = Row(lower, upper, nullCount) +} - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class IntColumnStats extends ColumnStats { + var upper = Int.MinValue + var lower = Int.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - - orderedState = orderedState match { - case UNINITIALIZED => - lastValue = field - INITIALIZED - - case INITIALIZED => - // If all the integers in the column are the same, ordered state is set to Ascending. - // TODO (lian) Confirm whether this is the standard behaviour. - val nextState = if (field >= lastValue) ASCENDING else DESCENDING - _maxDelta = math.abs(field - lastValue) - lastValue = field - nextState - - case ASCENDING if field < lastValue => - UNORDERED - - case DESCENDING if field > lastValue => - UNORDERED - - case state @ (ASCENDING | DESCENDING) => - _maxDelta = _maxDelta.max(field - lastValue) - lastValue = field - state - - case _ => - orderedState + if (!row.isNullAt(ordinal)) { + val value = row.getInt(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 } } + + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class StringColumnStats extends BasicColumnStats(STRING) { - override def initialBounds = (null, null) +private[sql] class StringColumnStats extends ColumnStats { + var upper: String = null + var lower: String = null + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field - if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field - } - - override def contains(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 - } - } - - override def isAbove(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - field.compareTo(upperBound) < 0 + if (!row.isNullAt(ordinal)) { + val value = row.getString(ordinal) + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 } } - override def isBelow(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) < 0 - } - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) { - override def initialBounds = (null, null) +private[sql] class TimestampColumnStats extends ColumnStats { + var upper: Timestamp = null + var lower: Timestamp = null + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field - if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field - } - - override def contains(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Timestamp] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 } } - override def isAbove(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - field.compareTo(upperBound) < 0 - } - } - - override def isBelow(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) < 0 - } - } + def collectedStatistics = Row(lower, upper, nullCount) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 794bc60d0e315..9a61600115872 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -158,9 +158,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (v) 1.toByte else 0.toByte) } - override def extract(buffer: ByteBuffer) = { - if (buffer.get() == 1) true else false - } + override def extract(buffer: ByteBuffer) = buffer.get() == 1 override def setField(row: MutableRow, ordinal: Int, value: Boolean) { row.setBoolean(ordinal, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index e63b4903041f6..6eab2f23c18e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,32 +19,41 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{LeafNode, SparkPlan} -object InMemoryRelation { +private[sql] object InMemoryRelation { def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = new InMemoryRelation(child.output, useCompression, batchSize, child)() } +private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) + private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, child: SparkPlan) - (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) + (private var _cachedColumnBuffers: RDD[CachedBatch] = null) extends LogicalPlan with MultiInstanceRelation { + override lazy val statistics = + Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + + val partitionStatistics = new PartitionStatistics(output) + // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output val cached = child.execute().mapPartitions { baseIterator => - new Iterator[Array[ByteBuffer]] { + new Iterator[CachedBatch] { def next() = { val columnBuilders = output.map { attribute => val columnType = ColumnType(attribute.dataType) @@ -65,7 +74,10 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - columnBuilders.map(_.build()) + val stats = Row.fromSeq( + columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + + CachedBatch(columnBuilders.map(_.build()), stats) } def hasNext = baseIterator.hasNext @@ -76,11 +88,8 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers = cached } - override def children = Seq.empty - override def references = Set.empty - override def newInstance() = { new InMemoryRelation( output.map(_.newInstance), @@ -95,48 +104,146 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], + predicates: Seq[Expression], relation: InMemoryRelation) extends LeafNode { + @transient override val sqlContext = relation.child.sqlContext + override def output: Seq[Attribute] = attributes - override def execute() = { - relation.cachedColumnBuffers.mapPartitions { iterator => - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = - if (attributes.isEmpty) { - Seq(0) - } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) - } + // Returned filter predicate should return false iff it is impossible for the input expression + // to evaluate to `true' based on statistics collected about this partition batch. + val buildFilter: PartialFunction[Expression, Expression] = { + case And(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => + buildFilter(lhs) && buildFilter(rhs) - new Iterator[Row] { - private[this] var columnBuffers: Array[ByteBuffer] = null - private[this] var columnAccessors: Seq[ColumnAccessor] = null - nextBatch() + case Or(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => + buildFilter(lhs) || buildFilter(rhs) - private[this] val nextRow = new GenericMutableRow(columnAccessors.length) + case EqualTo(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l && l <= aStats.upperBound - def nextBatch() = { - columnBuffers = iterator.next() - columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) - } + case EqualTo(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l && l <= aStats.upperBound + + case LessThan(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound < l + + case LessThan(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + l < aStats.upperBound + + case LessThanOrEqual(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l + + case LessThanOrEqual(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + l <= aStats.upperBound + + case GreaterThan(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + l < aStats.upperBound + + case GreaterThan(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound < l + + case GreaterThanOrEqual(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + l <= aStats.upperBound - override def next() = { - if (!columnAccessors.head.hasNext) { - nextBatch() + case GreaterThanOrEqual(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l + } + + val partitionFilters = { + predicates.flatMap { p => + val filter = buildFilter.lift(p) + val boundFilter = + filter.map( + BindReferences.bindReference( + _, + relation.partitionStatistics.schema, + allowFailures = true)) + + boundFilter.foreach(_ => + filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f"))) + + // If the filter can't be resolved then we are missing required statistics. + boundFilter.filter(_.resolved) + } + } + + val readPartitions = sparkContext.accumulator(0) + val readBatches = sparkContext.accumulator(0) + + private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning + + override def execute() = { + readPartitions.setValue(0) + readBatches.setValue(0) + + relation.cachedColumnBuffers.mapPartitions { iterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + relation.partitionStatistics.schema) + + // Find the ordinals of the requested columns. If none are requested, use the first. + val requestedColumns = if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } + + val rows = iterator + // Skip pruned batches + .filter { cachedBatch => + if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) { + def statsString = relation.partitionStatistics.schema + .zip(cachedBatch.stats) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + readBatches += 1 + true } + } + // Build column accessors + .map { cachedBatch => + requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + } + // Extract rows via column accessors + .flatMap { columnAccessors => + val nextRow = new GenericMutableRow(columnAccessors.length) + new Iterator[Row] { + override def next() = { + var i = 0 + while (i < nextRow.length) { + columnAccessors(i).extractTo(nextRow, i) + i += 1 + } + nextRow + } - var i = 0 - while (i < nextRow.length) { - columnAccessors(i).extractTo(nextRow, i) - i += 1 + override def hasNext = columnAccessors.head.hasNext } - nextRow } - override def hasNext = columnAccessors.head.hasNext || iterator.hasNext + if (rows.hasNext) { + readPartitions += 1 } + + rows } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index f631ee76fcd78..a72970eef7aa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -49,6 +49,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { } abstract override def appendFrom(row: Row, ordinal: Int) { + columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) nulls.putInt(pos) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 463a1d32d7fd7..be9f155253d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -175,7 +175,7 @@ case class Aggregate( private[this] val resultProjection = new InterpretedMutableProjection( resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow + private[this] val joinedRow = new JoinedRow4 override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 77dc2ad733215..927f40063e47e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree @@ -35,18 +36,26 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def outputPartitioning = newPartitioning - def output = child.output + override def output = child.output - def execute() = attachTree(this , "execute") { + /** We must copy rows when sort based shuffle is on */ + protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + + override def execute() = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. - val rdd = child.execute().mapPartitions { iter => - @transient val hashExpressions = - newMutableProjection(expressions, child.output)() - - val mutablePair = new MutablePair[Row, Row]() - iter.map(r => mutablePair.update(hashExpressions(r), r)) + val rdd = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => + val hashExpressions = newProjection(expressions, child.output) + iter.map(r => (hashExpressions(r), r.copy())) + } + } else { + child.execute().mapPartitions { iter => + val hashExpressions = newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() + iter.map(r => mutablePair.update(hashExpressions(r), r)) + } } val part = new HashPartitioner(numPartitions) val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) @@ -54,13 +63,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => + val rdd = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} + } else { + child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Row, Null](null, null) + iter.map(row => mutablePair.update(row, null)) + } + } + // TODO: RangePartitioner should take an Ordering. implicit val ordering = new RowOrdering(sortingExpressions, child.output) - val rdd = child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Row, Null](null, null) - iter.map(row => mutablePair.update(row, null)) - } val part = new RangePartitioner(numPartitions, rdd, ascending = true) val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) @@ -68,9 +82,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.map(_._1) case SinglePartition => - val rdd = child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Null, Row]() - iter.map(r => mutablePair.update(null, r)) + val rdd = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) } + } else { + child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Null, Row]() + iter.map(r => mutablePair.update(null, r)) + } } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 4a26934c49c93..b3edd5020fa8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.types._ @@ -103,13 +104,48 @@ case class GeneratedAggregate( updateCount :: updateSum :: Nil, result ) + + case m @ Max(expr) => + val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() + val initialValue = Literal(null, expr.dataType) + val updateMax = MaxOf(currentMax, expr) + + AggregateEvaluation( + currentMax :: Nil, + initialValue :: Nil, + updateMax :: Nil, + currentMax) + + case CollectHashSet(Seq(expr)) => + val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)() + val initialValue = NewSet(expr.dataType) + val addToSet = AddItemToSet(expr, set) + + AggregateEvaluation( + set :: Nil, + initialValue :: Nil, + addToSet :: Nil, + set) + + case CombineSetsAndCount(inputSet) => + val ArrayType(inputType, _) = inputSet.dataType + val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() + val initialValue = NewSet(inputType) + val collectSets = CombineSets(set, inputSet) + + AggregateEvaluation( + set :: Nil, + initialValue :: Nil, + collectSets :: Nil, + CountSet(set)) } val computationSchema = computeFunctions.flatMap(_.schema) - val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { - case (agg, func) => agg.id -> func.result - }.toMap + val resultMap: Map[TreeNodeRef, Expression] = + aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => new TreeNodeRef(agg) -> func.result + }.toMap val namedGroups = groupingExpressions.zipWithIndex.map { case (ne: NamedExpression, _) => (ne, ne) @@ -122,7 +158,7 @@ case class GeneratedAggregate( // The set of expressions that produce the final output given the aggregation buffer and the // grouping expressions. val resultExpressions = aggregateExpressions.map(_.transform { - case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) case e: Expression if groupMap.contains(e) => groupMap(e) }) @@ -151,7 +187,7 @@ case class GeneratedAggregate( (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") - val joinedRow = new JoinedRow + val joinedRow = new JoinedRow3 if (groupingExpressions.isEmpty) { // TODO: Codegening anything other than the updateProjection is probably over kill. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 21cbbc9772a00..2b8913985b028 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -49,7 +49,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected val sqlContext = SparkPlan.currentContext.get() + protected[spark] val sqlContext = SparkPlan.currentContext.get() protected def sparkContext = sqlContext.sparkContext @@ -141,10 +141,9 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ extends LogicalPlan with MultiInstanceRelation { def output = alreadyPlanned.output - override def references = Set.empty override def children = Nil - override final def newInstance: this.type = { + override final def newInstance(): this.type = { SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 34654447a5f4b..077e6ebc5f11e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -28,9 +28,13 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} +import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} + private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { val kryo = new Kryo() @@ -41,6 +45,13 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], new HyperLogLogSerializer) kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer) + + // Specific hashsets must come first TODO: Move to core. + kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) + kryo.register(classOf[LongHashSet], new LongHashSetSerializer) + kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], + new OpenHashSetSerializer) + kryo.setReferences(false) kryo.setClassLoader(Utils.getSparkClassLoader) new AllScalaRegistrar().apply(kryo) @@ -109,3 +120,78 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] { HyperLogLog.Builder.build(bytes) } } + +private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { + def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) { + val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] + output.writeInt(hs.size) + val iterator = hs.iterator + while(iterator.hasNext) { + val row = iterator.next() + rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values) + } + } + + def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = { + val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] + val numItems = input.readInt() + val set = new OpenHashSet[Any](numItems + 1) + var i = 0 + while (i < numItems) { + val row = + new GenericRow(rowSerializer.read( + kryo, + input, + classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) + set.add(row) + i += 1 + } + set + } +} + +private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] { + def write(kryo: Kryo, output: Output, hs: IntegerHashSet) { + output.writeInt(hs.size) + val iterator = hs.iterator + while(iterator.hasNext) { + val value: Int = iterator.next() + output.writeInt(value) + } + } + + def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = { + val numItems = input.readInt() + val set = new IntegerHashSet + var i = 0 + while (i < numItems) { + val value = input.readInt() + set.add(value) + i += 1 + } + set + } +} + +private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] { + def write(kryo: Kryo, output: Output, hs: LongHashSet) { + output.writeInt(hs.size) + val iterator = hs.iterator + while(iterator.hasNext) { + val value = iterator.next() + output.writeLong(value) + } + } + + def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = { + val numItems = input.readInt() + val set = new LongHashSet + var i = 0 + while (i < numItems) { + val value = input.readLong() + set.add(value) + i += 1 + } + set + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f0c958fdb537f..7943d6e1b6fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ @@ -148,7 +149,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { - case _: Sum | _: Count => false + case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + // The generated set implementation is pretty limited ATM. + case CollectHashSet(exprs) if exprs.size == 1 && + Seq(IntegerType, LongType).contains(exprs.head.dataType) => false case _ => true } @@ -239,8 +243,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { pruneFilterProject( projectList, filters, - identity[Seq[Expression]], // No filters are pushed down. - InMemoryColumnarTableScan(_, mem)) :: Nil + identity[Seq[Expression]], // All filters still need to be evaluated. + InMemoryColumnarTableScan(_, filters, mem)) :: Nil case _ => Nil } } @@ -297,8 +301,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.SetCommand(key, value) => Seq(execution.SetCommand(key, value, plan.output)(context)) - case logical.ExplainCommand(logicalPlan) => - Seq(execution.ExplainCommand(logicalPlan, plan.output)(context)) + case logical.ExplainCommand(logicalPlan, extended) => + Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f9dfa3c92f1eb..cac376608be29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.{HashPartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} -import org.apache.spark.sql.SQLContext +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, SinglePartition, UnspecifiedDistribution} import org.apache.spark.util.MutablePair /** @@ -96,7 +96,11 @@ case class Limit(limit: Int, child: SparkPlan) // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again + /** We must copy rows when sort based shuffle is on */ + private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + override def output = child.output + override def outputPartitioning = SinglePartition /** * A custom implementation modeled after the take function on RDDs but which never runs any job @@ -143,9 +147,15 @@ case class Limit(limit: Int, child: SparkPlan) } override def execute() = { - val rdd = child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Boolean, Row]() - iter.take(limit).map(row => mutablePair.update(false, row)) + val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => + iter.take(limit).map(row => (false, row.copy())) + } + } else { + child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Boolean, Row]() + iter.take(limit).map(row => mutablePair.update(false, row)) + } } val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) @@ -164,6 +174,7 @@ case class Limit(limit: Int, child: SparkPlan) case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output = child.output + override def outputPartitioning = SinglePartition val ordering = new RowOrdering(sortOrder, child.output) @@ -204,13 +215,6 @@ case class Sort( */ @DeveloperApi object ExistingRdd { - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.orNull - case s: Seq[Any] => s.map(convertToCatalyst) - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { @@ -222,7 +226,7 @@ object ExistingRdd { bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = convertToCatalyst(r.productElement(i)) + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) i += 1 } @@ -244,6 +248,7 @@ object ExistingRdd { case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } + /** * :: DeveloperApi :: * Computes the set of distinct input rows using a HashSet. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 38f37564f1788..94543fc95b470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,11 +21,13 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { + this: SparkPlan => + /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field @@ -35,7 +37,11 @@ trait Command { * The `execute()` method of all the physical command classes should reference `sideEffectResult` * so that the command can be executed eagerly right after the command query is created. */ - protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any] + protected[sql] lazy val sideEffectResult: Seq[Row] = Seq.empty[Row] + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } /** @@ -47,17 +53,17 @@ case class SetCommand( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[Row] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")) } else { context.setConf(k, v) - Array(s"$k=$v") + Seq(Row(s"$k=$v")) } // Query the value bound to key k. @@ -72,29 +78,31 @@ case class SetCommand( "hive-hwi-0.12.0.jar", "hive-0.12.0.jar").mkString(":") - Array( - "system:java.class.path=" + hiveJars, - "system:sun.java.command=shark.SharkServer2") - } - else { - Array(s"$k=${context.getConf(k, "")}") + context.getAllConfs.map { case (k, v) => + Row(s"$k=$v") + }.toSeq ++ Seq( + Row("system:java.class.path=" + hiveJars), + Row("system:sun.java.command=shark.SharkServer2")) + } else { + if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) + } else { + Seq(Row(s"$k=${context.getConf(k, "")}")) + } } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => context.getAllConfs.map { case (k, v) => - s"$k=$v" + Row(s"$k=$v") }.toSeq case _ => throw new IllegalArgumentException() } - def execute(): RDD[Row] = { - val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } - context.sparkContext.parallelize(rows, 1) - } - override def otherCopyArgs = context :: Nil } @@ -108,20 +116,19 @@ case class SetCommand( */ @DeveloperApi case class ExplainCommand( - logicalPlan: LogicalPlan, output: Seq[Attribute])( + logicalPlan: LogicalPlan, output: Seq[Attribute], extended: Boolean)( @transient context: SQLContext) extends LeafNode with Command { // Run through the optimizer to generate the physical plan. - override protected[sql] lazy val sideEffectResult: Seq[String] = try { - "Physical execution plan:" +: context.executePlan(logicalPlan).executedPlan.toString.split("\n") - } catch { case cause: TreeNodeException[_] => - "Error occurred during query planning: " +: cause.getMessage.split("\n") - } + override protected[sql] lazy val sideEffectResult: Seq[Row] = try { + // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. + val queryExecution = context.executePlan(logicalPlan) + val outputString = if (extended) queryExecution.toString else queryExecution.simpleString - def execute(): RDD[Row] = { - val explanation = sideEffectResult.map(row => new GenericRow(Array[Any](row))) - context.sparkContext.parallelize(explanation, 1) + outputString.split("\n").map(Row(_)) + } catch { case cause: TreeNodeException[_] => + ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) } override def otherCopyArgs = context :: Nil @@ -140,12 +147,7 @@ case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: } else { context.uncacheTable(tableName) } - Seq.empty[Any] - } - - override def execute(): RDD[Row] = { - sideEffectResult - context.emptyResult + Seq.empty[Row] } override def output: Seq[Attribute] = Seq.empty @@ -159,15 +161,8 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { - Seq(("# Registered as a temporary table", null, null)) ++ - child.output.map(field => (field.name, field.dataType.toString, null)) - } - - override def execute(): RDD[Row] = { - val rows = sideEffectResult.map { - case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) - } - context.sparkContext.parallelize(rows, 1) + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + Row("# Registered as a temporary table", null, null) +: + child.output.map(field => Row(field.name, field.dataType.toString, null)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index f31df051824d7..a9535a750bcd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -23,6 +23,7 @@ import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.catalyst.trees.TreeNodeRef /** * :: DeveloperApi :: @@ -43,10 +44,10 @@ package object debug { implicit class DebugQuery(query: SchemaRDD) { def debug(): Unit = { val plan = query.queryExecution.executedPlan - val visited = new collection.mutable.HashSet[Long]() + val visited = new collection.mutable.HashSet[TreeNodeRef]() val debugPlan = plan transform { - case s: SparkPlan if !visited.contains(s.id) => - visited += s.id + case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => + visited += new TreeNodeRef(s) DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") @@ -58,8 +59,6 @@ package object debug { } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { - def references = Set.empty - def output = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { @@ -75,22 +74,22 @@ package object debug { } /** - * A collection of stats for each column of output. + * A collection of metrics for each column of output. * @param elementTypes the actual runtime types for the output. Useful when there are bugs * causing the wrong data to be projected. */ - case class ColumnStat( + case class ColumnMetrics( elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) val tupleCount = sparkContext.accumulator[Int](0) val numColumns = child.output.size - val columnStats = Array.fill(child.output.size)(new ColumnStat()) + val columnStats = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { println(s"== ${child.simpleString} ==") println(s"Tuples output: ${tupleCount.value}") - child.output.zip(columnStats).foreach { case(attr, stat) => - val actualDataTypes =stat.elementTypes.value.mkString("{", ",", "}") + child.output.zip(columnStats).foreach { case(attr, metric) => + val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") println(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index b08f9aacc1fcb..2890a563bed48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -92,7 +92,7 @@ trait HashJoin { private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. - private[this] val joinRow = new JoinedRow + private[this] val joinRow = new JoinedRow2 private[this] val joinKeys = streamSideKeyGenerator() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index b92091b560b1c..0977da3e8577c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -42,6 +42,7 @@ private[spark] case class PythonUDF( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, children: Seq[Expression]) extends Expression with SparkLogging { @@ -49,7 +50,6 @@ private[spark] case class PythonUDF( override def toString = s"PythonUDF#$name(${children.mkString(",")})" def nullable: Boolean = true - def references: Set[Attribute] = children.flatMap(_.references).toSet override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") } @@ -99,7 +99,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { logical.Project( l.output, l.transformExpressions { - case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute }.withNewChildren(newChildren)) } } @@ -113,7 +113,6 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() - def references = Set.empty def output = child.output :+ resultAttribute } @@ -147,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: udf.pythonIncludes, false, udf.pythonExec, - Seq[Broadcast[Array[Byte]]](), + udf.broadcastVars, udf.accumulator ).mapPartitions { iter => val pickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 1c0b03c684f10..873221835daf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -68,8 +68,15 @@ private[sql] object JsonRDD extends Logging { val (topLevel, structLike) = values.partition(_.size == 1) val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { - case ArrayType(StructType(Nil), _) => false - case ArrayType(_, _) => true + case ArrayType(elementType, _) => { + def hasInnerStruct(t: DataType): Boolean = t match { + case s: StructType => false + case ArrayType(t1, _) => hasInnerStruct(t1) + case o => true + } + + hasInnerStruct(elementType) + } case struct: StructType => false case _ => true } @@ -84,7 +91,18 @@ private[sql] object JsonRDD extends Logging { val dataType = resolved.get(prefix :+ name).get dataType match { case array: ArrayType => - Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true)) + // The pattern of this array is ArrayType(...(ArrayType(StructType))). + // Since the inner struct of array is a placeholder (StructType(Nil)), + // we need to replace this placeholder with the actual StructType (structType). + def getActualArrayType( + innerStruct: StructType, + currentArray: ArrayType): ArrayType = currentArray match { + case ArrayType(s: StructType, containsNull) => + ArrayType(innerStruct, containsNull) + case ArrayType(a: ArrayType, containsNull) => + ArrayType(getActualArrayType(innerStruct, a), containsNull) + } + Some(StructField(name, getActualArrayType(structType, array), nullable = true)) case struct: StructType => Some(StructField(name, structType, nullable = true)) // dataType is StringType means that we have resolved type conflicts involving // primitive types and complex types. So, the type of name has been relaxed to @@ -125,38 +143,31 @@ private[sql] object JsonRDD extends Logging { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p - .contains(t2)) - - // If found return the widest common type, otherwise None - val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last) - - if (returnType.isDefined) { - returnType.get - } else { - // t1 or t2 is a StructType, ArrayType, or an unexpected type. - (t1, t2) match { - case (other: DataType, NullType) => other - case (NullType, other: DataType) => other - case (StructType(fields1), StructType(fields2)) => { - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => { - val dataType = fieldTypes.map(field => field.dataType).reduce( - (type1: DataType, type2: DataType) => compatibleType(type1, type2)) - StructField(name, dataType, true) + HiveTypeCoercion.findTightestCommonType(t1, t2) match { + case Some(commonType) => commonType + case None => + // t1 or t2 is a StructType, ArrayType, or an unexpected type. + (t1, t2) match { + case (other: DataType, NullType) => other + case (NullType, other: DataType) => other + case (StructType(fields1), StructType(fields2)) => { + val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { + case (name, fieldTypes) => { + val dataType = fieldTypes.map(field => field.dataType).reduce( + (type1: DataType, type2: DataType) => compatibleType(type1, type2)) + StructField(name, dataType, true) + } } + StructType(newFields.toSeq.sortBy { + case StructField(name, _, _) => name + }) } - StructType(newFields.toSeq.sortBy { - case StructField(name, _, _) => name - }) + case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + // TODO: We should use JsonObjectStringType to mark that values of field will be + // strings and every string is a Json object. + case (_, _) => StringType } - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) - // TODO: We should use JsonObjectStringType to mark that values of field will be - // strings and every string is a Json object. - case (_, _) => StringType - } } } @@ -175,8 +186,7 @@ private[sql] object JsonRDD extends Logging { /** * Returns the element type of an JSON array. We go through all elements of this array * to detect any possible type conflict. We use [[compatibleType]] to resolve - * type conflicts. Right now, when the element of an array is another array, we - * treat the element as String. + * type conflicts. */ private def typeOfArray(l: Seq[Any]): ArrayType = { val containsNull = l.exists(v => v == null) @@ -223,18 +233,24 @@ private[sql] object JsonRDD extends Logging { } case (key: String, array: Seq[_]) => { // The value associated with the key is an array. - typeOfArray(array) match { + // Handle inner structs of an array. + def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. - array.asInstanceOf[Seq[Map[String, Any]]].flatMap { + v.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { - case (k, dataType) => (s"$key.$k", dataType) - } :+ (key, ArrayType(StructType(Nil), containsNull)) + case (k, t) => (s"$key.$k", t) + } } - case ArrayType(elementType, containsNull) => - (key, ArrayType(elementType, containsNull)) :: Nil + case ArrayType(t1, containsNull) => + v.asInstanceOf[Seq[Any]].flatMap { + element => buildKeyPathForInnerStructs(element, t1) + } + case other => Nil } + val elementType = typeOfArray(array) + buildKeyPathForInnerStructs(array, elementType) :+ (key, elementType) } case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil } @@ -346,8 +362,6 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case ArrayType(elementType, _) => - value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case StringType => toString(value) case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) @@ -355,6 +369,10 @@ private[sql] object JsonRDD extends Logging { case DecimalType => toDecimal(value) case BooleanType => value.asInstanceOf[BooleanType.JvmType] case NullType => null + + case ArrayType(elementType, _) => + value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) } } } @@ -363,22 +381,9 @@ private[sql] object JsonRDD extends Logging { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { - // StructType - case (StructField(name, fields: StructType, _), i) => - row.update(i, json.get(name).flatMap(v => Option(v)).map( - v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) - - // ArrayType(StructType) - case (StructField(name, ArrayType(structType: StructType, _), _), i) => - row.update(i, - json.get(name).flatMap(v => Option(v)).map( - v => v.asInstanceOf[Seq[Any]].map( - e => asRow(e.asInstanceOf[Map[String, Any]], structType))).orNull) - - // Other cases case (StructField(name, dataType, _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( - enforceCorrectType(_, dataType)).getOrElse(null)) + enforceCorrectType(_, dataType)).orNull) } row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 0a3b59cbc233a..2fc7e1cf23ab7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -23,7 +23,7 @@ import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType /** @@ -58,6 +58,7 @@ private[sql] object CatalystConverter { // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). // Note that "array" for the array elements is chosen by ParquetAvro. // Using a different value will result in Parquet silently dropping columns. + val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" val ARRAY_ELEMENTS_SCHEMA_NAME = "array" val MAP_KEY_SCHEMA_NAME = "key" val MAP_VALUE_SCHEMA_NAME = "value" @@ -82,6 +83,9 @@ private[sql] object CatalystConverter { case ArrayType(elementType: DataType, false) => { new CatalystArrayConverter(elementType, fieldIndex, parent) } + case ArrayType(elementType: DataType, true) => { + new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent) + } case StructType(fields: Seq[StructField]) => { new CatalystStructConverter(fields.toArray, fieldIndex, parent) } @@ -278,14 +282,14 @@ private[parquet] class CatalystGroupConverter( */ private[parquet] class CatalystPrimitiveRowConverter( protected[parquet] val schema: Array[FieldType], - protected[parquet] var current: ParquetRelation.RowType) + protected[parquet] var current: MutableRow) extends CatalystConverter { // This constructor is used for the root converter only def this(attributes: Array[Attribute]) = this( attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), - new ParquetRelation.RowType(attributes.length)) + new SpecificMutableRow(attributes.map(_.dataType))) protected [parquet] val converters: Array[Converter] = schema.zipWithIndex.map { @@ -299,7 +303,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override val parent = null // Should be only called in root group converter! - override def getCurrentRecord: ParquetRelation.RowType = current + override def getCurrentRecord: Row = current override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) @@ -378,7 +382,7 @@ private[parquet] class CatalystPrimitiveConverter( parent.updateLong(fieldIndex, value) } -object CatalystArrayConverter { +private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } @@ -567,6 +571,85 @@ private[parquet] class CatalystNativeArrayConverter( } } +/** + * A `parquet.io.api.GroupConverter` that converts a single-element groups that + * match the characteristics of an array contains null (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.ArrayType]]. + * + * @param elementType The type of the array elements (complex or primitive) + * @param index The position of this (array) field inside its parent converter + * @param parent The parent converter + * @param buffer A data buffer + */ +private[parquet] class CatalystArrayContainsNullConverter( + val elementType: DataType, + val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var buffer: Buffer[Any]) + extends CatalystConverter { + + def this(elementType: DataType, index: Int, parent: CatalystConverter) = + this( + elementType, + index, + parent, + new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) + + protected[parquet] val converter: Converter = new CatalystConverter { + + private var current: Any = null + + val converter = CatalystConverter.createConverter( + new CatalystConverter.FieldType( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + elementType, + false), + fieldIndex = 0, + parent = this) + + override def getConverter(fieldIndex: Int): Converter = converter + + override def end(): Unit = parent.updateField(index, current) + + override def start(): Unit = { + current = null + } + + override protected[parquet] val size: Int = 1 + override protected[parquet] val index: Int = 0 + override protected[parquet] val parent = CatalystArrayContainsNullConverter.this + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + current = value + } + + override protected[parquet] def clearBuffer(): Unit = {} + } + + override def getConverter(fieldIndex: Int): Converter = converter + + // arrays have only one (repeated) field, which is its elements + override val size = 1 + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + buffer += value + } + + override protected[parquet] def clearBuffer(): Unit = { + buffer.clear() + } + + override def start(): Unit = {} + + override def end(): Unit = { + assert(parent != null) + // here we need to make sure to use ArrayScalaType + parent.updateField(index, buffer.toArray.toSeq) + clearBuffer() + } +} + /** * This converter is for multi-element groups of primitive or complex types * that have repetition level optional or required (so struct fields). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 2298a9b933df5..7c83f1cad7d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import java.nio.ByteBuffer + import org.apache.hadoop.conf.Configuration import parquet.filter._ @@ -25,12 +27,13 @@ import parquet.column.ColumnReader import com.google.common.io.BaseEncoding +import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -object ParquetFilters { +private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" // set this to false if pushdown should be disabled val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" @@ -237,7 +240,8 @@ object ParquetFilters { */ def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = { if (filters.length > 0) { - val serialized: Array[Byte] = SparkSqlSerializer.serialize(filters) + val serialized: Array[Byte] = + SparkEnv.get.closureSerializer.newInstance().serialize(filters).array() val encoded: String = BaseEncoding.base64().encode(serialized) conf.set(PARQUET_FILTER_DATA, encoded) } @@ -252,7 +256,7 @@ object ParquetFilters { val data = conf.get(PARQUET_FILTER_DATA) if (data != null) { val decoded: Array[Byte] = BaseEncoding.base64().decode(data) - SparkSqlSerializer.deserialize(decoded) + SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(decoded)) } else { Seq() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 1713ae6fb5d93..5ae768293a22e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -100,8 +100,13 @@ private[sql] object ParquetRelation { // The compression type type CompressionType = parquet.hadoop.metadata.CompressionCodecName - // The default compression - val defaultCompression = CompressionCodecName.GZIP + // The parquet compression short names + val shortParquetCompressionCodecNames = Map( + "NONE" -> CompressionCodecName.UNCOMPRESSED, + "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, + "SNAPPY" -> CompressionCodecName.SNAPPY, + "GZIP" -> CompressionCodecName.GZIP, + "LZO" -> CompressionCodecName.LZO) /** * Creates a new ParquetRelation and underlying Parquetfile for the given LogicalPlan. Note that @@ -141,9 +146,8 @@ private[sql] object ParquetRelation { conf: Configuration, sqlContext: SQLContext): ParquetRelation = { val path = checkPath(pathString, allowExisting, conf) - if (conf.get(ParquetOutputFormat.COMPRESSION) == null) { - conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name()) - } + conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse( + sqlContext.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name()) ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) new ParquetRelation(path.toString, Some(conf), sqlContext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index f6cfab736d98a..a5a5d139a65cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -139,7 +139,7 @@ case class ParquetTableScan( partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) new Iterator[Row] { - private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null) + private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null) def hasNext = iter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 6a657c20fe46c..bdf02401b21be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -173,7 +173,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - case t @ ArrayType(_, false) => writeArray( + case t @ ArrayType(_, _) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) case t @ MapType(_, _, _) => writeMap( @@ -228,45 +228,57 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } } - // TODO: support null values, see - // https://issues.apache.org/jira/browse/SPARK-1649 private[parquet] def writeArray( schema: ArrayType, array: CatalystConverter.ArrayScalaType[_]): Unit = { val elementType = schema.elementType writer.startGroup() if (array.size > 0) { - writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - var i = 0 - while(i < array.size) { - writeValue(elementType, array(i)) - i = i + 1 + if (schema.containsNull) { + writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) + var i = 0 + while (i < array.size) { + writer.startGroup() + if (array(i) != null) { + writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + writeValue(elementType, array(i)) + writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + } + writer.endGroup() + i = i + 1 + } + writer.endField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) + } else { + writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + var i = 0 + while (i < array.size) { + writeValue(elementType, array(i)) + i = i + 1 + } + writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } - writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } writer.endGroup() } - // TODO: support null values, see - // https://issues.apache.org/jira/browse/SPARK-1649 private[parquet] def writeMap( schema: MapType, map: CatalystConverter.MapScalaType[_, _]): Unit = { writer.startGroup() if (map.size > 0) { writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) - writer.startGroup() - writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - for(key <- map.keys) { + for ((key, value) <- map) { + writer.startGroup() + writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) writeValue(schema.keyType, key) + writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + if (value != null) { + writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + writeValue(schema.valueType, value) + writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + } + writer.endGroup() } - writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - for(value <- map.values) { - writeValue(schema.valueType, value) - } - writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - writer.endGroup() writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) } writer.endGroup() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index c79a9ac2dad81..2941b9793597f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -119,7 +119,13 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { + val bag = field.asGroupType() + assert(bag.getFieldCount == 1) + ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + } else { + ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + } } case ParquetOriginalType.MAP => { assert( @@ -129,28 +135,32 @@ private[parquet] object ParquetTypesConverter extends Logging { assert( keyValueGroup.getFieldCount == 2, "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) - assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true - // at here. - MapType(keyType, valueType) + MapType(keyType, valueType, + keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } case _ => { // Note: the order of these checks is important! if (correspondsToMap(groupType)) { // MapType val keyValueGroup = groupType.getFields.apply(0).asGroupType() - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) - assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true - // at here. - MapType(keyType, valueType) + MapType(keyType, valueType, + keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } else if (correspondsToArray(groupType)) { // ArrayType - val elementType = toDataType(groupType.getFields.apply(0), isBinaryAsString) - ArrayType(elementType, containsNull = false) + val field = groupType.getFields.apply(0) + if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { + val bag = field.asGroupType() + assert(bag.getFieldCount == 1) + ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + } else { + ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + } } else { // everything else: StructType val fields = groupType .getFields @@ -249,13 +259,27 @@ private[parquet] object ParquetTypesConverter extends Logging { inArray = true) ConversionPatterns.listType(repetition, name, parquetElementType) } + case ArrayType(elementType, true) => { + val parquetElementType = fromDataType( + elementType, + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + nullable = true, + inArray = false) + ConversionPatterns.listType( + repetition, + name, + new ParquetGroupType( + Repetition.REPEATED, + CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, + parquetElementType)) + } case StructType(structFields) => { val fields = structFields.map { field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) } new ParquetGroupType(repetition, name, fields) } - case MapType(keyType, valueType, _) => { + case MapType(keyType, valueType, valueContainsNull) => { val parquetKeyType = fromDataType( keyType, @@ -266,7 +290,7 @@ private[parquet] object ParquetTypesConverter extends Logging { fromDataType( valueType, CatalystConverter.MAP_VALUE_SCHEMA_NAME, - nullable = false, + nullable = valueContainsNull, inArray = false) ConversionPatterns.mapType( repetition, @@ -370,17 +394,14 @@ private[parquet] object ParquetTypesConverter extends Logging { throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") } val path = origPath.makeQualified(fs) - if (!fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"Expected $path for be a directory with Parquet files/metadata") - } - ParquetRelation.enableLogForwarding() val children = fs.listStatus(path).filterNot { status => val name = status.getPath.getName - name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME + (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE } + ParquetRelation.enableLogForwarding() + // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row // groups. Since Parquet schema is replicated among all row groups, we only need to touch a // single row group to read schema related metadata. Notice that we are making assumptions that diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index cf7d79f42db1d..8fb59c5830f6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -24,7 +24,7 @@ class DataTypeSuite extends FunSuite { test("construct an ArrayType") { val array = ArrayType(StringType) - assert(ArrayType(StringType, false) === array) + assert(ArrayType(StringType, true) === array) } test("construct an MapType") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 1a6a6c17473a3..d001abb7e1fcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test._ /* Implicits */ @@ -133,6 +135,18 @@ class DslQuerySuite extends QueryTest { mapData.take(1).toSeq) } + test("SPARK-3395 limit distinct") { + val filtered = TestData.testData2 + .distinct() + .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending)) + .limit(1) + .registerTempTable("onerow") + checkAnswer( + sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), + (1, 1, 1, 1) :: + (1, 1, 1, 2) :: Nil) + } + test("average") { checkAnswer( testData2.groupBy()(avg('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 651cb735ab7d9..811319e0a6601 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} class RowSuite extends FunSuite { @@ -43,4 +43,10 @@ class RowSuite extends FunSuite { assert(expected.getBoolean(2) === actual2.getBoolean(2)) assert(expected(3) === actual2(3)) } + + test("SpecificMutableRow.update with null") { + val row = new SpecificMutableRow(Seq(IntegerType)) + row(0) = null + assert(row.isNullAt(0)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9b2a36d33fca7..67563b6c55f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,23 +17,70 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test._ +import org.scalatest.BeforeAndAfterAll +import java.util.TimeZone /* Implicits */ import TestSQLContext._ import TestData._ -class SQLQuerySuite extends QueryTest { +class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData + var origZone: TimeZone = _ + override protected def beforeAll() { + origZone = TimeZone.getDefault + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + } + + override protected def afterAll() { + TimeZone.setDefault(origZone) + } + + + test("SPARK-3176 Added Parser of SQL ABS()") { + checkAnswer( + sql("SELECT ABS(-1.3)"), + 1.3) + checkAnswer( + sql("SELECT ABS(0.0)"), + 0.0) + checkAnswer( + sql("SELECT ABS(2.5)"), + 2.5) + } + + test("SPARK-3176 Added Parser of SQL LAST()") { + checkAnswer( + sql("SELECT LAST(n) FROM lowerCaseData"), + 4) + } + + test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), "test") } + test("SQRT") { + checkAnswer( + sql("SELECT SQRT(key) FROM testData"), + (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq + ) + } + + test("SQRT with automatic string casts") { + checkAnswer( + sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"), + (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq + ) + } + test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( sql("SELECT substr(tableName, 1, 2) FROM tableName"), @@ -49,6 +96,34 @@ class SQLQuerySuite extends QueryTest { "st") } + test("SPARK-3173 Timestamp support in the parser") { + checkAnswer(sql( + "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), + Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + + checkAnswer(sql( + "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), + Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + + checkAnswer(sql( + "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), + Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + + checkAnswer(sql( + """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' + AND time>'1970-01-01 00:00:00.001'"""), + Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + + checkAnswer(sql( + "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), + Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), + Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + + checkAnswer(sql( + "SELECT time FROM timestamps WHERE time='123'"), + Nil) + } + test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), @@ -304,6 +379,25 @@ class SQLQuerySuite extends QueryTest { (null, null, 6, "F") :: Nil) } + test("SPARK-3349 partitioning after limit") { + /* + sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") + .limit(2) + .registerTempTable("subset1") + sql("SELECT DISTINCT n FROM lowerCaseData") + .limit(2) + .registerTempTable("subset2") + checkAnswer( + sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), + (3, "c", 3) :: + (4, "d", 4) :: Nil) + checkAnswer( + sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), + (1, "a", 1) :: + (2, "b", 2) :: Nil) + */ + } + test("mixed-case keywords") { checkAnswer( sql( @@ -384,18 +478,48 @@ class SQLQuerySuite extends QueryTest { (3, null))) } - test("EXCEPT") { + test("UNION") { + checkAnswer( + sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), + (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: + (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + checkAnswer( + sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), + (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil) + checkAnswer( + sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), + (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") :: + (4, "d") :: (4, "d") :: Nil) + } + test("UNION with column mismatches") { + // Column name mismatches are allowed. + checkAnswer( + sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), + (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: + (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( - sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData "), + sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), + ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_))) + // Column type mismatches where a coercion is not possible, in this case between integer + // and array types, trigger a TreeNodeException. + intercept[TreeNodeException[_]] { + sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() + } + } + + test("EXCEPT") { + checkAnswer( + sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil) checkAnswer( - sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData "), Nil) + sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( - sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData "), Nil) + sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) } test("INTERSECT") { @@ -525,4 +649,28 @@ class SQLQuerySuite extends QueryTest { (3, null) :: (4, 2147483644) :: Nil) } + + test("SPARK-3423 BETWEEN") { + checkAnswer( + sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), + Seq((5, "5"), (6, "6"), (7, "7")) + ) + + checkAnswer( + sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), + Seq((7, "7")) + ) + + checkAnswer( + sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), + Seq() + ) + } + + test("cast boolean to string") { + // TODO Ensure true/false string letter casing is consistent with Hive in all cases. + checkAnswer( + sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), + ("true", "false") :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 5b84c658db942..e24c521d24c7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( @@ -56,6 +57,22 @@ case class OptionalReflectData( case class ReflectBinary(data: Array[Byte]) +case class Nested(i: Option[Int], s: String) + +case class Data( + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapContainsNul: Map[Int, Option[Long]], + nested: Nested) + +case class ComplexReflectData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[Option[Int]], + mapField: Map[Int, Long], + mapFieldContainsNull: Map[Int, Option[Long]], + dataField: Data) + class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, @@ -90,4 +107,33 @@ class ScalaReflectionRelationSuite extends FunSuite { val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } + + test("query complex data") { + val data = ComplexReflectData( + Seq(1, 2, 3), + Seq(Some(1), Some(2), None), + Map(1 -> 10L, 2 -> 20L), + Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), + Data( + Seq(10, 20, 30), + Seq(Some(10), Some(20), None), + Map(10 -> 100L, 20 -> 200L), + Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), + Nested(None, "abc"))) + val rdd = sparkContext.parallelize(data :: Nil) + rdd.registerTempTable("reflectComplexData") + + assert(sql("SELECT * FROM reflectComplexData").collect().head === + new GenericRow(Array[Any]( + Seq(1, 2, 3), + Seq(1, 2, null), + Map(1 -> 10L, 2 -> 20L), + Map(1 -> 10L, 2 -> 20L, 3 -> null), + new GenericRow(Array[Any]( + Seq(10, 20, 30), + Seq(10, 20, null), + Map(10 -> 100L, 20 -> 200L), + Map(10 -> 100L, 20 -> 200L, 30 -> null), + new GenericRow(Array[Any](null, "abc"))))))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index c3ec82fb69778..eb33a61c6e811 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -151,4 +151,9 @@ object TestData { TimestampField(new Timestamp(i)) }) timestamps.registerTempTable("timestamps") + + case class IntField(i: Int) + // An RDD with 4 elements and 8 partitions + val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + withEmptyParts.registerTempTable("withEmptyParts") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 76aa9b0081d7e..ef9b76b1e251e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +case class FunctionResult(f1: String, f2: String) + class UDFSuite extends QueryTest { test("Simple UDF") { @@ -33,4 +35,14 @@ class UDFSuite extends QueryTest { registerFunction("strLenScala", (_: String).length + (_:Int)) assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) } + + + test("struct UDF") { + registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + + val result= + sql("SELECT returnStruct('test', 'test2') as ret") + .select("ret.f1".attr).first().getString(0) + assert(result == "test") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 5f61fb5e16ea3..cde91ceb68c98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,29 +19,30 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.types._ class ColumnStatsSuite extends FunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN) - testColumnStats(classOf[ByteColumnStats], BYTE) - testColumnStats(classOf[ShortColumnStats], SHORT) - testColumnStats(classOf[IntColumnStats], INT) - testColumnStats(classOf[LongColumnStats], LONG) - testColumnStats(classOf[FloatColumnStats], FLOAT) - testColumnStats(classOf[DoubleColumnStats], DOUBLE) - testColumnStats(classOf[StringColumnStats], STRING) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP) - - def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]]( + testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) + + def testColumnStats[T <: NativeType, U <: ColumnStats]( columnStatsClass: Class[U], - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + initialStatistics: Row) { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - assertResult(columnStats.initialBounds, "Wrong initial bounds") { - (columnStats.lowerBound, columnStats.upperBound) + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => + assert(actual === expected) } } @@ -49,14 +50,16 @@ class ColumnStatsSuite extends FunSuite { import ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() - val rows = Seq.fill(10)(makeRandomRow(columnType)) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.map(_.head.asInstanceOf[T#JvmType]) + val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound) - assertResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound) + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 736c0f8571e9e..0e3c67f5eed29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar -import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.SparkLogicalPlan import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{SQLConf, QueryTest, TestData} class InMemoryColumnarQuerySuite extends QueryTest { - import TestData._ - import TestSQLContext._ + import org.apache.spark.sql.TestData._ + import org.apache.spark.sql.test.TestSQLContext._ test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan @@ -33,6 +32,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { checkAnswer(scan, testData.collect().toSeq) } + test("default size avoids broadcast") { + // TODO: Improve this test when we have better statistics + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst") + cacheTable("sizeTst") + assert( + table("sizeTst").queryExecution.logical.statistics.sizeInBytes > autoBroadcastJoinThreshold) + } + test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, plan) @@ -85,4 +92,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) } + + test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { + checkAnswer( + sql("SELECT * FROM withEmptyParts"), + withEmptyParts.collect().toSeq) + + TestSQLContext.cacheTable("withEmptyParts") + + checkAnswer( + sql("SELECT * FROM withEmptyParts"), + withEmptyParts.collect().toSeq) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index dc813fe146c47..a77262534a352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala new file mode 100644 index 0000000000000..5d2fd4959197c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql._ +import org.apache.spark.sql.test.TestSQLContext._ + +case class IntegerData(i: Int) + +class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { + val originalColumnBatchSize = columnBatchSize + val originalInMemoryPartitionPruning = inMemoryPartitionPruning + + override protected def beforeAll() { + // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch + setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) + rawData.registerTempTable("intData") + + // Enable in-memory partition pruning + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + } + + override protected def afterAll() { + setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + } + + before { + cacheTable("intData") + } + + after { + uncacheTable("intData") + } + + // Comparisons + checkBatchPruning("i = 1", Seq(1), 1, 1) + checkBatchPruning("1 = i", Seq(1), 1, 1) + checkBatchPruning("i < 12", 1 to 11, 1, 2) + checkBatchPruning("i <= 11", 1 to 11, 1, 2) + checkBatchPruning("i > 88", 89 to 100, 1, 2) + checkBatchPruning("i >= 89", 89 to 100, 1, 2) + checkBatchPruning("12 > i", 1 to 11, 1, 2) + checkBatchPruning("11 >= i", 1 to 11, 1, 2) + checkBatchPruning("88 < i", 89 to 100, 1, 2) + checkBatchPruning("89 <= i", 89 to 100, 1, 2) + + // Conjunction and disjunction + checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) + checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) + checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + + // With unsupported predicate + checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) + checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + + def checkBatchPruning( + filter: String, + expectedQueryResult: Seq[Int], + expectedReadPartitions: Int, + expectedReadBatches: Int) { + + test(filter) { + val query = sql(s"SELECT * FROM intData WHERE $filter") + assertResult(expectedQueryResult.toArray, "Wrong query result") { + query.collect().map(_.head).toArray + } + + val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect { + case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) + }.head + + assert(readBatches === expectedReadBatches, "Wrong number of read batches") + assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 5fba00480967c..e01cc8b4d20f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.Row -import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats} +import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ class BooleanBitSetSuite extends FunSuite { @@ -31,7 +31,7 @@ class BooleanBitSetSuite extends FunSuite { // Tests encoder // ------------- - val builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet) + val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) val values = rows.map(_.head) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index d8ae2a26778c9..d2969d906c943 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -31,7 +31,7 @@ class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new StringColumnStats, STRING) def testDictionaryEncoding[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 17619dcf974e3..322f447c24840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -29,7 +29,7 @@ class IntegralDeltaSuite extends FunSuite { testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( - columnStats: NativeColumnStats[I], + columnStats: ColumnStats, columnType: NativeColumnType[I], scheme: IntegralDelta[I]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 40115beb98899..218c09ac26362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ class RunLengthEncodingSuite extends FunSuite { - testRunLengthEncoding(new BooleanColumnStats, BOOLEAN) + testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) testRunLengthEncoding(new IntColumnStats, INT) @@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new StringColumnStats, STRING) def testRunLengthEncoding[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 72c19fa31d980..7db723d648d80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ class TestCompressibleColumnBuilder[T <: NativeType]( - override val columnStats: NativeColumnStats[T], + override val columnStats: ColumnStats, override val columnType: NativeColumnType[T], override val schemes: Seq[CompressionScheme]) extends NativeColumnBuilder(columnStats, columnType) @@ -33,7 +33,7 @@ class TestCompressibleColumnBuilder[T <: NativeType]( object TestCompressibleColumnBuilder { def apply[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T], scheme: CompressionScheme) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 76b1724471442..37d64f0de7bab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -45,16 +45,16 @@ class PlannerSuite extends FunSuite { assert(aggregations.size === 2) } - test("count distinct is not partially aggregated") { + test("count distinct is partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed val planned = HashAggregation(query) - assert(planned.isEmpty) + assert(planned.nonEmpty) } - test("mixed aggregates are not partially aggregated") { + test("mixed aggregates are partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed val planned = HashAggregation(query) - assert(planned.isEmpty) + assert(planned.nonEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 58b1e23891a3b..b50d93855405a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -130,11 +130,11 @@ class JsonSuite extends QueryTest { checkDataType( ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) checkDataType( - ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false)) + ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, true)) checkDataType( ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false)) checkDataType( - ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType)) + ArrayType(IntegerType, false), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) // StructType checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) @@ -201,26 +201,26 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonRDD(complexFieldAndType) val expectedSchema = StructType( - StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) :: - StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType), true) :: - StructField("arrayOfBoolean", ArrayType(BooleanType), true) :: - StructField("arrayOfDouble", ArrayType(DoubleType), true) :: - StructField("arrayOfInteger", ArrayType(IntegerType), true) :: - StructField("arrayOfLong", ArrayType(LongType), true) :: + StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: + StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: + StructField("arrayOfLong", ArrayType(LongType, false), true) :: StructField("arrayOfNull", ArrayType(StringType, true), true) :: - StructField("arrayOfString", ArrayType(StringType), true) :: + StructField("arrayOfString", ArrayType(StringType, false), true) :: StructField("arrayOfStruct", ArrayType( StructType( StructField("field1", BooleanType, true) :: StructField("field2", StringType, true) :: - StructField("field3", StringType, true) :: Nil)), true) :: + StructField("field3", StringType, true) :: Nil), false), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: StructField("field2", DecimalType, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( - StructField("field1", ArrayType(IntegerType), true) :: - StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil) + StructField("field1", ArrayType(IntegerType, false), true) :: + StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) assert(expectedSchema === jsonSchemaRDD.schema) @@ -441,7 +441,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) val expectedSchema = StructType( - StructField("array", ArrayType(IntegerType), true) :: + StructField("array", ArrayType(IntegerType, false), true) :: StructField("num_struct", StringType, true) :: StructField("str_array", StringType, true) :: StructField("struct", StructType( @@ -467,7 +467,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: StructField("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil)), true) :: Nil) + StructField("field", LongType, true) :: Nil), false), true) :: Nil) assert(expectedSchema === jsonSchemaRDD.schema) @@ -492,7 +492,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("a", BooleanType, true) :: StructField("b", LongType, true) :: - StructField("c", ArrayType(IntegerType), true) :: + StructField("c", ArrayType(IntegerType, false), true) :: StructField("d", StructType( StructField("field", BooleanType, true) :: Nil), true) :: StructField("e", StringType, true) :: Nil) @@ -581,4 +581,45 @@ class JsonSuite extends QueryTest { "this is a simple string.") :: Nil ) } + + test("SPARK-2096 Correctly parse dot notations") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType2) + jsonSchemaRDD.registerTempTable("jsonTable") + + checkAnswer( + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + (true, "str1") :: Nil + ) + checkAnswer( + sql( + """ + |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] + |from jsonTable + """.stripMargin), + ("str2", 6) :: Nil + ) + } + + test("SPARK-3390 Complex arrays") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType2) + jsonSchemaRDD.registerTempTable("jsonTable") + + checkAnswer( + sql( + """ + |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] + |from jsonTable + """.stripMargin), + (5, 7, 8) :: Nil + ) + checkAnswer( + sql( + """ + |select arrayOfArray2[0][0][0].inner1, arrayOfArray2[1][0], + |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 + |from jsonTable + """.stripMargin), + ("str1", Nil, "str4", 2) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index a88310b5f1b46..5f0b3959a63ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -82,4 +82,58 @@ object TestJsonData { """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) + + val complexFieldAndType2 = + TestSQLContext.sparkContext.parallelize( + """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], + "complexArrayOfStruct": [ + { + "field1": [ + { + "inner1": "str1" + }, + { + "inner2": ["str2", "str22"] + }], + "field2": [[1, 2], [3, 4]] + }, + { + "field1": [ + { + "inner2": ["str3", "str33"] + }, + { + "inner1": "str4" + }], + "field2": [[5, 6], [7, 8]] + }], + "arrayOfArray1": [ + [ + [5] + ], + [ + [6, 7], + [8] + ]], + "arrayOfArray2": [ + [ + [ + { + "inner1": "str1" + } + ] + ], + [ + [], + [ + {"inner2": ["str3", "str33"]}, + {"inner2": ["str4"], "inner1": "str11"} + ] + ], + [ + [ + {"inner3": [[{"inner4": 2}]]} + ] + ]] + }""" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 172dcd6aa0ee3..b0a06cd3ca090 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,25 +17,19 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} - import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job - -import org.apache.spark.SparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} +import org.apache.spark.sql.catalyst.types.IntegerType import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils - case class TestRDDEntry(key: Int, value: String) case class NullReflectData( @@ -78,7 +72,9 @@ case class AllDataTypesWithNonPrimitiveType( booleanField: Boolean, binaryField: Array[Byte], array: Seq[Int], - map: Map[Int, String], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapValueContainsNull: Map[Int, Option[Long]], data: Data) class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { @@ -86,11 +82,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA var testRDD: SchemaRDD = null - // TODO: remove this once SqlParser can parse nested select statements - var nestedParserSqlContext: NestedParserSQLContext = null - override def beforeAll() { - nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext) ParquetTestData.writeFile() ParquetTestData.writeFilterFile() ParquetTestData.writeNestedFile1() @@ -186,6 +178,100 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) } + test("Compression options for writing to a Parquetfile") { + val defaultParquetCompressionCodec = TestSQLContext.parquetCompressionCodec + import scala.collection.JavaConversions._ + + val file = getTempFilePath("parquet") + val path = file.toString + val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + + // test default compression codec + rdd.saveAsParquetFile(path) + var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) + .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct + assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + + parquetFile(path).registerTempTable("tmp") + checkAnswer( + sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + Utils.deleteRecursively(file) + + // test uncompressed parquet file with property value "UNCOMPRESSED" + TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "UNCOMPRESSED") + + rdd.saveAsParquetFile(path) + actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) + .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct + assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + + parquetFile(path).registerTempTable("tmp") + checkAnswer( + sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + Utils.deleteRecursively(file) + + // test uncompressed parquet file with property value "none" + TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "none") + + rdd.saveAsParquetFile(path) + actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) + .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct + assert(actualCodec === "UNCOMPRESSED" :: Nil) + + parquetFile(path).registerTempTable("tmp") + checkAnswer( + sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + Utils.deleteRecursively(file) + + // test gzip compression codec + TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "gzip") + + rdd.saveAsParquetFile(path) + actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) + .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct + assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + + parquetFile(path).registerTempTable("tmp") + checkAnswer( + sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + Utils.deleteRecursively(file) + + // test snappy compression codec + TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "snappy") + + rdd.saveAsParquetFile(path) + actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) + .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct + assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + + parquetFile(path).registerTempTable("tmp") + checkAnswer( + sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + Utils.deleteRecursively(file) + + // TODO: Lzo requires additional external setup steps so leave it out for now + // ref.: https://github.com/Parquet/parquet-mr/blob/parquet-1.5.0/parquet-hadoop/src/test/java/parquet/hadoop/example/TestInputOutputFormat.java#L169 + + // Set it back. + TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, defaultParquetCompressionCodec) + } + test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) @@ -193,7 +279,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .map(x => AllDataTypesWithNonPrimitiveType( s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, (0 to x).map(_.toByte).toArray, - (0 until x), (0 until x).map(i => i -> s"$i").toMap, Data((0 until x), Nested(x, s"$x")))) + (0 until x), + (0 until x).map(Option(_).filter(_ % 3 == 0)), + (0 until x).map(i => i -> i.toLong).toMap, + (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), + Data((0 until x), Nested(x, s"$x")))) .saveAsParquetFile(tempDir) val result = parquetFile(tempDir).collect() range.foreach { @@ -208,8 +298,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result(i).getBoolean(7) === (i % 2 == 0)) assert(result(i)(8) === (0 to i).map(_.toByte).toArray) assert(result(i)(9) === (0 until i)) - assert(result(i)(10) === (0 until i).map(i => i -> s"$i").toMap) - assert(result(i)(11) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) + assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null)) + assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap) + assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null)) + assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) } } @@ -318,8 +410,30 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value in line $i") + assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") + assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") + } + Utils.deleteRecursively(file) + } + + test("Read a parquet file instead of a directory") { + val file = getTempFilePath("parquet") + val path = file.toString + val fsPath = new Path(path) + val fs: FileSystem = fsPath.getFileSystem(TestSQLContext.sparkContext.hadoopConfiguration) + val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + rdd.coalesce(1).saveAsParquetFile(path) + + val children = fs.listStatus(fsPath).filter(_.getPath.getName.endsWith(".parquet")) + assert(children.length > 0) + val readFile = parquetFile(path + "/" + children(0).getPath.getName) + readFile.registerTempTable("tmpx") + val rdd_copy = sql("SELECT * FROM tmpx").collect() + val rdd_orig = rdd.collect() + for(i <- 0 to 99) { + assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") + assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") } Utils.deleteRecursively(file) } @@ -595,11 +709,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection in addressbook") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD data.registerTempTable("data") - val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") + val query = sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() assert(tmp.size === 2) assert(tmp(0).size === 2) @@ -610,21 +722,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Simple query on nested int data") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir2.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD data.registerTempTable("data") - val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() + val result1 = sql("SELECT entries[0].value FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) assert(result1(0)(0) === 2.5) - val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect() + val result2 = sql("SELECT entries[0] FROM data").collect() assert(result2.size === 1) val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] assert(subresult1.size === 2) assert(subresult1(0) === 2.5) assert(subresult1(1) === false) - val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect() + val result3 = sql("SELECT outerouter FROM data").collect() val subresult2 = result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] @@ -637,19 +747,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("nested structs") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir3.toString) + val data = parquetFile(ParquetTestData.testNestedDir3.toString) .toSchemaRDD data.registerTempTable("data") - val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() + val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) assert(result1(0)(0) === false) - val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() + val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() assert(result2.size === 1) assert(result2(0).size === 1) assert(result2(0)(0) === true) - val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() + val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() assert(result3.size === 1) assert(result3(0).size === 1) assert(result3(0)(0) === false) @@ -673,11 +782,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("map with struct values") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD data.registerTempTable("mapTable") - val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() + val result1 = sql("SELECT data2 FROM mapTable").collect() assert(result1.size === 1) val entry1 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] @@ -691,7 +798,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(entry2 != null) assert(entry2(0) === 49) assert(entry2(1) === null) - val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() + val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() assert(result2.size === 1) assert(result2(0)(0) === 42.toLong) assert(result2(0)(1) === "the answer") @@ -702,15 +809,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA // has no effect in this test case val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) - val result = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD + val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD result.saveAsParquetFile(tmpdir.toString) - nestedParserSqlContext - .parquetFile(tmpdir.toString) + parquetFile(tmpdir.toString) .toSchemaRDD .registerTempTable("tmpcopy") - val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() + val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() assert(tmpdata.size === 2) assert(tmpdata(0).size === 2) assert(tmpdata(0)(0) === "Julien Le Dem") @@ -721,20 +825,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Writing out Map and reading it back in") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) data.saveAsParquetFile(tmpdir.toString) - nestedParserSqlContext - .parquetFile(tmpdir.toString) + parquetFile(tmpdir.toString) .toSchemaRDD .registerTempTable("tmpmapcopy") - val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() + val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() assert(result1.size === 1) assert(result1(0)(0) === 2) - val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect() + val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() assert(result2.size === 1) val entry1 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] @@ -748,42 +849,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(entry2 != null) assert(entry2(0) === 49) assert(entry2(1) === null) - val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() + val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() assert(result3.size === 1) assert(result3(0)(0) === 42.toLong) assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } } - -// TODO: the code below is needed temporarily until the standard parser is able to parse -// nested field expressions correctly -class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) { - override protected[sql] val parser = new NestedSqlParser() -} - -class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) { - override def identChar = letter | elem('_') - delimiters += (".") -} - -class NestedSqlParser extends SqlParser { - override val lexical = new NestedSqlLexical(reservedWords) - - override protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { - case base ~ _ ~ ordinal => GetItem(base, ordinal) - } | - expression ~ "." ~ ident ^^ { - case base ~ _ ~ fieldName => GetField(base, fieldName) - } | - TRUE ^^^ Literal(true, BooleanType) | - FALSE ^^^ Literal(false, BooleanType) | - cast | - "(" ~> expression <~ ")" | - function | - "-" ~> literal ^^ UnaryMinus | - ident ^^ UnresolvedAttribute | - "*" ^^^ Star(None) | - literal -} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index c6f60c18804a4..124fc107cb8aa 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 699a1103f3248..bd3f68d92d8c7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -39,7 +39,9 @@ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. */ -class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { +private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) + extends OperationManager with Logging { + val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") @@ -66,9 +68,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage if (!iter.hasNext) { new RowSet() } else { - val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No. + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows) + var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) while (curRow < maxRows && iter.hasNext) { val sparkRow = iter.next() @@ -151,7 +154,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def getResultSetSchema: TableSchema = { - logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 2bf8cfdcacd22..3475c2c9db080 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,41 +18,112 @@ package org.apache.spark.sql.hive.thriftserver -import java.io.{BufferedReader, InputStreamReader, PrintWriter} +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration._ +import scala.concurrent.{Await, Future, Promise} +import scala.sys.process.{Process, ProcessLogger} + +import java.io._ +import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} -class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { - val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli") - val METASTORE_PATH = TestUtils.getMetastorePath("cli") +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util.getTempFilePath + +class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { + def runCliWithin( + timeout: FiniteDuration, + extraArgs: Seq[String] = Seq.empty)( + queriesAndExpectedAnswers: (String, String)*) { + + val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip + val warehousePath = getTempFilePath("warehouse") + val metastorePath = getTempFilePath("metastore") + val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) - override def beforeAll() { - val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" - val commands = - s"""../../bin/spark-sql + val command = { + val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" + s"""$cliScript | --master local - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$WAREHOUSE_PATH - """.stripMargin.split("\\s+") - - val pb = new ProcessBuilder(commands: _*) - process = pb.start() - outputWriter = new PrintWriter(process.getOutputStream, true) - inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) - errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) - waitForOutput(inputReader, "spark-sql>") + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + """.stripMargin.split("\\s+").toSeq ++ extraArgs + } + + // AtomicInteger is needed because stderr and stdout of the forked process are handled in + // different threads. + val next = new AtomicInteger(0) + val foundAllExpectedAnswers = Promise.apply[Unit]() + val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) + val buffer = new ArrayBuffer[String]() + + def captureOutput(source: String)(line: String) { + buffer += s"$source> $line" + if (line.contains(expectedAnswers(next.get()))) { + if (next.incrementAndGet() == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) + } + } + } + + // Searching expected output line from both stdout and stderr of the CLI process + val process = (Process(command) #< queryStream).run( + ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) + + Future { + val exitValue = process.exitValue() + logInfo(s"Spark SQL CLI process exit value: $exitValue") + } + + try { + Await.result(foundAllExpectedAnswers.future, timeout) + } catch { case cause: Throwable => + logError( + s""" + |======================= + |CliSuite failure output + |======================= + |Spark SQL CLI command line: ${command.mkString(" ")} + | + |Executed query ${next.get()} "${queries(next.get())}", + |But failed to capture expected output "${expectedAnswers(next.get())}" within $timeout. + | + |${buffer.mkString("\n")} + |=========================== + |End CliSuite failure output + |=========================== + """.stripMargin, cause) + } finally { + warehousePath.delete() + metastorePath.delete() + process.destroy() + } } - override def afterAll() { - process.destroy() - process.waitFor() + test("Simple commands") { + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + runCliWithin(1.minute)( + "CREATE TABLE hive_test(key INT, val STRING);" + -> "OK", + "SHOW TABLES;" + -> "hive_test", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" + -> "OK", + "CACHE TABLE hive_test;" + -> "Time taken: ", + "SELECT COUNT(*) FROM hive_test;" + -> "5", + "DROP TABLE hive_test" + -> "Time taken: " + ) } - test("simple commands") { - val dataFilePath = getDataFile("data/files/small_kv.txt") - executeQuery("create table hive_test1(key int, val string);") - executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;") - executeQuery("cache table hive_test1", "Time taken") + test("Single command with -e") { + runCliWithin(1.minute, Seq("-e", "SHOW TABLES;"))("" -> "OK") } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index aedef6ce1f5f2..38977ff162097 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -17,32 +17,32 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.{Await, Future, Promise} +import scala.sys.process.{Process, ProcessLogger} -import java.io.{BufferedReader, InputStreamReader} +import java.io.File import java.net.ServerSocket -import java.sql.{Connection, DriverManager, Statement} +import java.sql.{DriverManager, Statement} +import java.util.concurrent.TimeoutException import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.apache.hive.jdbc.HiveDriver +import org.scalatest.FunSuite import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath /** - * Test for the HiveThriftServer2 using JDBC. + * Tests for the HiveThriftServer2 using JDBC. */ -class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging { +class HiveThriftServer2Suite extends FunSuite with Logging { + Class.forName(classOf[HiveDriver].getCanonicalName) - val WAREHOUSE_PATH = getTempFilePath("warehouse") - val METASTORE_PATH = getTempFilePath("metastore") - - val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver" - val TABLE = "test" - val HOST = "localhost" - val PORT = { + private val listeningHost = "localhost" + private val listeningPort = { // Let the system to choose a random available port to avoid collision with other parallel // builds. val socket = new ServerSocket(0) @@ -51,106 +51,126 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt port } - // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. - val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean - - Class.forName(DRIVER_NAME) + private val warehousePath = getTempFilePath("warehouse") + private val metastorePath = getTempFilePath("metastore") + private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - override def beforeAll() { launchServer() } + def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) { + val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) - override def afterAll() { stopServer() } - - private def launchServer(args: Seq[String] = Seq.empty) { - // Forking a new process to start the Hive Thrift server. The reason to do this is it is - // hard to clean up Hive resources entirely, so we just start a new process and kill - // that process for cleanup. - val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" val command = - s"""../../sbin/start-thriftserver.sh + s"""$serverScript | --master local | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$METASTORE_PATH - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$HOST - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$PORT - """.stripMargin.split("\\s+") - - val pb = new ProcessBuilder(command ++ args: _*) - val environment = pb.environment() - environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) - environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) - process = pb.start() - inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) - errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) - waitForOutput(inputReader, "ThriftBinaryCLIService listening on") - - // Spawn a thread to read the output from the forked process. - // Note that this is necessary since in some configurations, log4j could be blocked - // if its output to stderr are not read, and eventually blocking the entire test suite. - future { - while (true) { - val stdout = readFrom(inputReader) - val stderr = readFrom(errorReader) - if (VERBOSE && stdout.length > 0) { - println(stdout) - } - if (VERBOSE && stderr.length > 0) { - println(stderr) - } - Thread.sleep(50) + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort + """.stripMargin.split("\\s+").toSeq + + val serverStarted = Promise[Unit]() + val buffer = new ArrayBuffer[String]() + + def captureOutput(source: String)(line: String) { + buffer += s"$source> $line" + if (line.contains("ThriftBinaryCLIService listening on")) { + serverStarted.success(()) } } - } - private def stopServer() { - process.destroy() - process.waitFor() + val process = Process(command).run( + ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) + + Future { + val exitValue = process.exitValue() + logInfo(s"Spark SQL Thrift server process exit value: $exitValue") + } + + val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" + val user = System.getProperty("user.name") + + try { + Await.result(serverStarted.future, timeout) + + val connection = DriverManager.getConnection(jdbcUri, user, "") + val statement = connection.createStatement() + + try { + f(statement) + } finally { + statement.close() + connection.close() + } + } catch { + case cause: Exception => + cause match { + case _: TimeoutException => + logError(s"Failed to start Hive Thrift server within $timeout", cause) + case _ => + } + logError( + s""" + |===================================== + |HiveThriftServer2Suite failure output + |===================================== + |HiveThriftServer2 command line: ${command.mkString(" ")} + |JDBC URI: $jdbcUri + |User: $user + | + |${buffer.mkString("\n")} + |========================================= + |End HiveThriftServer2Suite failure output + |========================================= + """.stripMargin, cause) + } finally { + warehousePath.delete() + metastorePath.delete() + process.destroy() + } } - test("test query execution against a Hive Thrift server") { - Thread.sleep(5 * 1000) - val dataFilePath = getDataFile("data/files/small_kv.txt") - val stmt = createStatement() - stmt.execute("DROP TABLE IF EXISTS test") - stmt.execute("DROP TABLE IF EXISTS test_cached") - stmt.execute("CREATE TABLE test(key INT, val STRING)") - stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") - stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4") - stmt.execute("CACHE TABLE test_cached") - - var rs = stmt.executeQuery("SELECT COUNT(*) FROM test") - rs.next() - assert(rs.getInt(1) === 5) - - rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached") - rs.next() - assert(rs.getInt(1) === 4) - - stmt.close() + test("Test JDBC query execution") { + startThriftServerWithin() { statement => + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + val queries = Seq( + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } } test("SPARK-3004 regression: result set containing NULL") { - Thread.sleep(5 * 1000) - val dataFilePath = getDataFile("data/files/small_kv_with_null.txt") - val stmt = createStatement() - stmt.execute("DROP TABLE IF EXISTS test_null") - stmt.execute("CREATE TABLE test_null(key INT, val STRING)") - stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") - - val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL") - var count = 0 - while (rs.next()) { - count += 1 - } - assert(count === 5) + startThriftServerWithin() { statement => + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource( + "data/files/small_kv_with_null.txt") - stmt.close() - } + val queries = Seq( + "DROP TABLE IF EXISTS test_null", + "CREATE TABLE test_null(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") - def getConnection: Connection = { - val connectURI = s"jdbc:hive2://localhost:$PORT/" - DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") - } + queries.foreach(statement.execute) - def createStatement(): Statement = getConnection.createStatement() + val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + + (0 until 5).foreach { _ => + resultSet.next() + assert(resultSet.getInt(1) === 0) + assert(resultSet.wasNull()) + } + + assert(!resultSet.next()) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala deleted file mode 100644 index bb2242618fbef..0000000000000 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.io.{BufferedReader, PrintWriter} -import java.text.SimpleDateFormat -import java.util.Date - -import org.apache.hadoop.hive.common.LogUtils -import org.apache.hadoop.hive.common.LogUtils.LogInitializationException - -object TestUtils { - val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss") - - def getWarehousePath(prefix: String): String = { - System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" + - timestamp.format(new Date) - } - - def getMetastorePath(prefix: String): String = { - System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" + - timestamp.format(new Date) - } - - // Dummy function for initialize the log4j properties. - def init() { } - - // initialize log4j - try { - LogUtils.initHiveLog4j() - } catch { - case e: LogInitializationException => // Ignore the error. - } -} - -trait TestUtils { - var process : Process = null - var outputWriter : PrintWriter = null - var inputReader : BufferedReader = null - var errorReader : BufferedReader = null - - def executeQuery( - cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = { - println("Executing: " + cmd + ", expecting output: " + outputMessage) - outputWriter.write(cmd + "\n") - outputWriter.flush() - waitForQuery(timeout, outputMessage) - } - - protected def waitForQuery(timeout: Long, message: String): String = { - if (waitForOutput(errorReader, message, timeout)) { - Thread.sleep(500) - readOutput() - } else { - assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput()) - null - } - } - - // Wait for the specified str to appear in the output. - protected def waitForOutput( - reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = { - val startTime = System.currentTimeMillis - var out = "" - while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) { - out += readFrom(reader) - } - out.contains(str) - } - - // Read stdout output and filter out garbage collection messages. - protected def readOutput(): String = { - val output = readFrom(inputReader) - // Remove GC Messages - val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC")) - .mkString("\n") - filteredOutput - } - - protected def readFrom(reader: BufferedReader): String = { - var out = "" - var c = 0 - while (reader.ready) { - c = reader.read() - out += c.asInstanceOf[Char] - } - out - } - - protected def getDataFile(name: String) = { - Thread.currentThread().getContextClassLoader.getResource(name) - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 210753efe7678..ab487d673e813 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.hive.execution import java.io.File -import java.util.TimeZone +import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive /** @@ -29,23 +30,34 @@ import org.apache.spark.sql.hive.test.TestHive */ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath - lazy val hiveQueryDir = TestHive.getHiveFile("ql" + File.separator + "src" + - File.separator + "test" + File.separator + "queries" + File.separator + "clientpositive") + private lazy val hiveQueryDir = TestHive.getHiveFile( + "ql/src/test/queries/clientpositive".split("/").mkString(File.separator)) - var originalTimeZone: TimeZone = _ + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + private val originalColumnBatchSize = TestHive.columnBatchSize + private val originalInMemoryPartitionPruning = TestHive.inMemoryPartitionPruning def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) - originalTimeZone = TimeZone.getDefault TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + // Set a relatively small column batch size for testing purposes + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5") + // Enable in-memory partition pruning for testing purposes + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ @@ -310,6 +322,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "create_nested_type", "create_skewed_table1", "create_struct_table", + "cross_join", "ct_case_insensitive", "database_location", "database_properties", @@ -643,9 +656,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_create_table_db_table", "show_create_table_does_not_exist", "show_create_table_index", + "show_columns", "show_describe_func_quotes", "show_functions", "show_partitions", + "show_tblproperties", "skewjoinopt13", "skewjoinopt18", "skewjoinopt9", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 30ff277e67c88..45a4c6dc98da0 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ff32c7c90a0d2..e0be09e6793ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -255,13 +255,20 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } // Note that HiveUDFs will be overridden by functions registered in this context. + @transient override protected[sql] lazy val functionRegistry = new HiveFunctionRegistry with OverrideFunctionRegistry /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) { + override val extendedRules = + catalog.CreateTables :: + catalog.PreInsertionCasts :: + ExtractPythonUdfs :: + Nil + } /** * Runs the specified SQL query using Hive. @@ -352,9 +359,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { - // TODO: Create mixin for the analyzer instead of overriding things here. - override lazy val optimizedPlan = - optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))) override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) @@ -388,7 +392,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -408,7 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // be similar with Hive. describeHiveTableCommand.hiveString case command: PhysicalCommand => - command.sideEffectResult.map(_.toString) + command.sideEffectResult.map(_.head.toString) case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq @@ -423,7 +427,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { logical match { case _: NativeCommand => "" case _: SetCommand => "" - case _ => executedPlan.toString + case _ => super.simpleString } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 3b371211e14cd..2c0db9be57e54 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -54,8 +54,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with db: Option[String], tableName: String, alias: Option[String]): LogicalPlan = synchronized { - val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) + val (databaseName, tblName) = processDatabaseAndTableName( + db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) val table = client.getTable(databaseName, tblName) val partitions: Seq[Partition] = if (table.isPartitioned) { @@ -109,18 +109,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with */ object CreateTables extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case InsertIntoCreatedTable(db, tableName, child) => + // Wait until children are resolved. + case p: LogicalPlan if !p.childrenResolved => p + + case CreateTableAsSelect(db, tableName, child) => val (dbName, tblName) = processDatabaseAndTableName(db, tableName) val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - createTable(databaseName, tblName, child.output) - - InsertIntoTable( - EliminateAnalysisOperators( - lookupRelation(Some(databaseName), tblName, None)), - Map.empty, - child, - overwrite = false) + CreateTableAsSelect(Some(databaseName), tableName, child) } } @@ -130,15 +126,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with */ object PreInsertionCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transform { - // Wait until children are resolved + // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => + case p @ InsertIntoTable( + LowerCaseSchema(table: MetastoreRelation), _, child, _) => castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), _, child, _) => + LowerCaseSchema( + InMemoryRelation(_, _, _, + HiveTableScan(_, table, _))), _, child, _) => castChildOutput(p, table, child) } @@ -265,9 +263,9 @@ private[hive] case class MetastoreRelation // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException // which indicates the SerDe we used is not Serializable. - @transient lazy val hiveQlTable = new Table(table) + @transient val hiveQlTable = new Table(table) - def hiveQlPartitions = partitions.map { p => + @transient val hiveQlPartitions = partitions.map { p => new Partition(hiveQlTable, p) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 1d9ba1b24a7a4..21ecf17028dbc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -44,6 +44,8 @@ private[hive] case class SourceCommand(filePath: String) extends Command private[hive] case class AddFile(filePath: String) extends Command +private[hive] case class AddJar(path: String) extends Command + private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command private[hive] case class AnalyzeTable(tableName: String) extends Command @@ -54,6 +56,7 @@ private[hive] object HiveQl { "TOK_DESCFUNCTION", "TOK_DESCDATABASE", "TOK_SHOW_CREATETABLE", + "TOK_SHOWCOLUMNS", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", @@ -61,6 +64,7 @@ private[hive] object HiveQl { "TOK_SHOWINDEXES", "TOK_SHOWPARTITIONS", "TOK_SHOWTABLES", + "TOK_SHOW_TBLPROPERTIES", "TOK_LOCKTABLE", "TOK_SHOWLOCKS", @@ -229,7 +233,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - NativeCommand(sql) + AddJar(sql.trim.drop(8).trim) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { @@ -409,10 +413,9 @@ private[hive] object HiveQl { ExplainCommand(NoRelation) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. - val Some(query) :: _ :: _ :: Nil = + val Some(query) :: _ :: extended :: Nil = getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - // TODO: support EXTENDED? - ExplainCommand(nodeToPlan(query)) + ExplainCommand(nodeToPlan(query), extended != None) case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -486,7 +489,7 @@ private[hive] object HiveQl { val (db, tableName) = extractDbNameTableName(tableNameParts) - InsertIntoCreatedTable(db, tableName, nodeToPlan(query)) + CreateTableAsSelect(db, tableName, nodeToPlan(query)) // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. case Token("TOK_CREATETABLE", _) => NativePlaceholder @@ -773,6 +776,7 @@ private[hive] object HiveQl { val joinType = joinToken match { case "TOK_JOIN" => Inner + case "TOK_CROSSJOIN" => Inner case "TOK_RIGHTOUTERJOIN" => RightOuter case "TOK_LEFTOUTERJOIN" => LeftOuter case "TOK_FULLOUTERJOIN" => FullOuter @@ -887,6 +891,7 @@ private[hive] object HiveQl { val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r val SUBSTR = "(?i)SUBSTR(?:ING)?".r + val SQRT = "(?i)SQRT".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -956,6 +961,7 @@ private[hive] object HiveQl { case Token(DIV(), left :: right:: Nil) => Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) + case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg)) /* Comparisons */ case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) @@ -1014,9 +1020,9 @@ private[hive] object HiveQl { /* Other functions */ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) /* UDFs - Must be last otherwise will preempt built in functions */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 389ace726d205..43dd3d234f73a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,17 +18,19 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.Experimental -import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.parquet.{ParquetRelation, ParquetTableScan} +import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} +import org.apache.spark.sql.hive +import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.{SQLContext, SchemaRDD} import scala.collection.JavaConversions._ @@ -79,9 +81,9 @@ private[hive] trait HiveStrategies { hiveContext.convertMetastoreParquet => // Filter out all predicates that only deal with partition keys - val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val partitionsKeys = AttributeSet(relation.partitionKeys) val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.map(_.exprId).subsetOf(partitionKeyIds) + _.references.subsetOf(partitionsKeys) } // We are going to throw the predicates and projection back at the whole optimization @@ -135,7 +137,7 @@ private[hive] trait HiveStrategies { .fakeOutput(projectList.map(_.toAttribute)):: Nil } else { hiveContext - .parquetFile(relation.hiveQlTable.getDataLocation.getPath) + .parquetFile(relation.hiveQlTable.getDataLocation.toString) .lowerCase .where(unresolvedOtherPredicates) .select(unresolvedProjection:_*) @@ -163,6 +165,16 @@ private[hive] trait HiveStrategies { InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.CreateTableAsSelect(database, tableName, child) => + val query = planLater(child) + CreateTableAsSelect( + database.get, + tableName, + query, + InsertIntoHiveTable(_: MetastoreRelation, + Map(), + query, + true)(hiveContext)) :: Nil case _ => Nil } } @@ -176,9 +188,9 @@ private[hive] trait HiveStrategies { case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. - val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val partitionKeyIds = AttributeSet(relation.partitionKeys) val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.map(_.exprId).subsetOf(partitionKeyIds) + _.references.subsetOf(partitionKeyIds) } pruneFilterProject( @@ -193,12 +205,13 @@ private[hive] trait HiveStrategies { case class HiveCommandStrategy(context: HiveContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.NativeCommand(sql) => - NativeCommand(sql, plan.output)(context) :: Nil + case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + + case hive.DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil - case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case hive.AddJar(path) => execution.AddJar(path) :: Nil - case AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil + case hive.AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 82c88280d7754..329f80cad471e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} -import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector @@ -249,6 +249,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { FileInputFormat.setInputPaths(jobConf, path) if (tableDesc != null) { + PlanUtils.configureInputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } val bufferSize = System.getProperty("spark.buffer.size", "65536") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a013f3f7a805f..6974f3e581b97 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -269,7 +269,74 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |) """.stripMargin.cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd - ) + ), + // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING + // IS NOT YET SUPPORTED + TestTable("episodes_part", + s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) + |PARTITIONED BY (doctor_pt INT) + |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' + |STORED AS + |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' + |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + // WORKAROUND: Required to pass schema to SerDe for partitioned tables. + // TODO: Pass this automatically from the table to partitions. + s""" + |ALTER TABLE episodes_part SET SERDEPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s""" + INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) + SELECT title, air_date, doctor FROM episodes + """.cmd + ) ) hiveQTestUtilTables.foreach(registerTestTable) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala new file mode 100644 index 0000000000000..71ea774d77795 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LowerCaseSchema +import org.apache.spark.sql.execution.{SparkPlan, Command, LeafNode} +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.MetastoreRelation + +/** + * :: Experimental :: + * Create table and insert the query result into it. + * @param database the database name of the new relation + * @param tableName the table name of the new relation + * @param insertIntoRelation function of creating the `InsertIntoHiveTable` + * by specifying the `MetaStoreRelation`, the data will be inserted into that table. + * TODO Add more table creating properties, e.g. SerDe, StorageHandler, in-memory cache etc. + */ +@Experimental +case class CreateTableAsSelect( + database: String, + tableName: String, + query: SparkPlan, + insertIntoRelation: MetastoreRelation => InsertIntoHiveTable) + extends LeafNode with Command { + + def output = Seq.empty + + // A lazy computing of the metastoreRelation + private[this] lazy val metastoreRelation: MetastoreRelation = { + // Create the table + val sc = sqlContext.asInstanceOf[HiveContext] + sc.catalog.createTable(database, tableName, query.output, false) + // Get the Metastore Relation + sc.catalog.lookupRelation(Some(database), tableName, None) match { + case LowerCaseSchema(r: MetastoreRelation) => r + case o: MetastoreRelation => o + } + } + + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + insertIntoRelation(metastoreRelation).execute + Seq.empty[Row] + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } + + override def argString: String = { + s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index a40e89e0d382b..317801001c7a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{Command, LeafNode} import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} @@ -41,26 +41,21 @@ case class DescribeHiveTableCommand( extends LeafNode with Command { // Strings with the format like Hive. It is used for result comparison in our unit tests. - lazy val hiveString: Seq[String] = { - val alignment = 20 - val delim = "\t" - - sideEffectResult.map { - case (name, dataType, comment) => - String.format("%-" + alignment + "s", name) + delim + - String.format("%-" + alignment + "s", dataType) + delim + - String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) - } + lazy val hiveString: Seq[String] = sideEffectResult.map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, Option(comment.asInstanceOf[String]).getOrElse("None")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") } - override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil val columns: Seq[FieldSchema] = table.hiveQlTable.getCols val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols results ++= columns.map(field => (field.getName, field.getType, field.getComment)) - if (!partitionColumns.isEmpty) { + if (partitionColumns.nonEmpty) { val partColumnInfo = partitionColumns.map(field => (field.getName, field.getType, field.getComment)) results ++= @@ -74,14 +69,9 @@ case class DescribeHiveTableCommand( results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) } - results - } - - override def execute(): RDD[Row] = { - val rows = sideEffectResult.map { - case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + results.map { case (name, dataType, comment) => + Row(name, dataType, comment) } - context.sparkContext.parallelize(rows, 1) } override def otherCopyArgs = context :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 39033bdeac4b0..a284a91a91e31 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -53,9 +53,9 @@ case class InsertIntoHiveTable( (@transient sc: HiveContext) extends UnaryNode { - val outputClass = newSerializer(table.tableDesc).getSerializedClass - @transient private val hiveContext = new Context(sc.hiveconf) - @transient private val db = Hive.get(sc.hiveconf) + @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass + @transient private lazy val hiveContext = new Context(sc.hiveconf) + @transient private lazy val db = Hive.get(sc.hiveconf) private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala index fe6031678f70f..8f10e1ba7f426 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala @@ -32,16 +32,7 @@ case class NativeCommand( @transient context: HiveContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) - - override def execute(): RDD[Row] = { - if (sideEffectResult.size == 0) { - context.emptyResult - } else { - val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) - context.sparkContext.parallelize(rows, 1) - } - } + override protected[sql] lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_)) override def otherCopyArgs = context :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 2985169da033c..d61c5e274a596 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -33,19 +33,13 @@ import org.apache.spark.sql.hive.HiveContext */ @DeveloperApi case class AnalyzeTable(tableName: String) extends LeafNode with Command { - def hiveContext = sqlContext.asInstanceOf[HiveContext] def output = Seq.empty - override protected[sql] lazy val sideEffectResult = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { hiveContext.analyze(tableName) - Seq.empty[Any] - } - - override def execute(): RDD[Row] = { - sideEffectResult - sparkContext.emptyRDD[Row] + Seq.empty[Row] } } @@ -55,20 +49,30 @@ case class AnalyzeTable(tableName: String) extends LeafNode with Command { */ @DeveloperApi case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with Command { - def hiveContext = sqlContext.asInstanceOf[HiveContext] def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Any] = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { val ifExistsClause = if (ifExists) "IF EXISTS " else "" hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(None, tableName) - Seq.empty + Seq.empty[Row] } +} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class AddJar(path: String) extends LeafNode with Command { + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + override def output = Seq.empty - override def execute(): RDD[Row] = { - sideEffectResult - sparkContext.emptyRDD[Row] + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + hiveContext.runSqlHive(s"ADD JAR $path") + hiveContext.sparkContext.addJar(path) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index c6497a15efa0c..7d1ad53d8bdb3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -88,7 +88,6 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu type EvaluatedType = Any def nullable = true - def references = children.flatMap(_.references).toSet lazy val function = createFunction[UDFType]() @@ -229,8 +228,6 @@ private[hive] case class HiveGenericUdaf( def nullable: Boolean = true - def references: Set[Attribute] = children.map(_.references).flatten.toSet - override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" def newInstance() = new HiveUdafFunction(functionClassName, children, this) @@ -253,8 +250,6 @@ private[hive] case class HiveGenericUdtf( children: Seq[Expression]) extends Generator with HiveInspectors with HiveFunctionFactory { - override def references = children.flatMap(_.references).toSet - @transient protected lazy val function: GenericUDTF = createFunction() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala index 544abfc32423c..abed299cd957f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector import org.apache.hadoop.io.Writable /** - * A placeholder that allows SparkSQL users to create metastore tables that are stored as + * A placeholder that allows Spark SQL users to create metastore tables that are stored as * parquet files. It is only intended to pass the checks that the serde is valid and exists * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan * when "spark.sql.hive.convertMetastoreParquet" is set to true. diff --git a/sql/hive/src/test/resources/golden/Read Partitioned with AvroSerDe-0-e4501461c855cc9071a872a64186c3de b/sql/hive/src/test/resources/golden/Read Partitioned with AvroSerDe-0-e4501461c855cc9071a872a64186c3de new file mode 100644 index 0000000000000..49c8434730ffa --- /dev/null +++ b/sql/hive/src/test/resources/golden/Read Partitioned with AvroSerDe-0-e4501461c855cc9071a872a64186c3de @@ -0,0 +1,8 @@ +The Eleventh Hour 3 April 2010 11 1 +The Doctor's Wife 14 May 2011 11 1 +Horror of Fang Rock 3 September 1977 4 1 +An Unearthly Child 23 November 1963 1 1 +The Mysterious Planet 6 September 1986 6 1 +Rose 26 March 2005 9 1 +The Power of the Daleks 5 November 1966 2 1 +Castrolava 4 January 1982 5 1 diff --git a/sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a b/sql/hive/src/test/resources/golden/case sensitivity when query Hive table-0-5d14d21a239daa42b086cc895215009a similarity index 100% rename from sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a rename to sql/hive/src/test/resources/golden/case sensitivity when query Hive table-0-5d14d21a239daa42b086cc895215009a diff --git a/sql/hive/src/test/resources/golden/case when then 1 else null end -0-f7c7fdd35c084bc797890aa08d33693c b/sql/hive/src/test/resources/golden/case when then 1 else null end -0-f7c7fdd35c084bc797890aa08d33693c new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then 1 else null end -0-f7c7fdd35c084bc797890aa08d33693c @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 b/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 new file mode 100644 index 0000000000000..d3827e75a5cad --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 @@ -0,0 +1 @@ +1.0 diff --git a/sql/hive/src/test/resources/golden/case when then 1L else null end -0-763ae85e7a52b4cf4162d6a8931716bb b/sql/hive/src/test/resources/golden/case when then 1L else null end -0-763ae85e7a52b4cf4162d6a8931716bb new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then 1L else null end -0-763ae85e7a52b4cf4162d6a8931716bb @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/case when then 1S else null end -0-6f5f3b3dbe9f1d1eb98443aef315b982 b/sql/hive/src/test/resources/golden/case when then 1S else null end -0-6f5f3b3dbe9f1d1eb98443aef315b982 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then 1S else null end -0-6f5f3b3dbe9f1d1eb98443aef315b982 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/case when then 1Y else null end -0-589982a400d86157791c7216b10b6b5d b/sql/hive/src/test/resources/golden/case when then 1Y else null end -0-589982a400d86157791c7216b10b6b5d new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then 1Y else null end -0-589982a400d86157791c7216b10b6b5d @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/case when then null else 1 end -0-48bd83660cf3ba93cdbdc24559092171 b/sql/hive/src/test/resources/golden/case when then null else 1 end -0-48bd83660cf3ba93cdbdc24559092171 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then null else 1 end -0-48bd83660cf3ba93cdbdc24559092171 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 b/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/case when then null else 1L end -0-a7f1305ea4f86e596c368e35e45cc4e5 b/sql/hive/src/test/resources/golden/case when then null else 1L end -0-a7f1305ea4f86e596c368e35e45cc4e5 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then null else 1L end -0-a7f1305ea4f86e596c368e35e45cc4e5 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/case when then null else 1S end -0-dfb61969e6cb6e6dbe89225b538c8d98 b/sql/hive/src/test/resources/golden/case when then null else 1S end -0-dfb61969e6cb6e6dbe89225b538c8d98 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then null else 1S end -0-dfb61969e6cb6e6dbe89225b538c8d98 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/case when then null else 1Y end -0-7f4c32299c3738739b678ece62752a7b b/sql/hive/src/test/resources/golden/case when then null else 1Y end -0-7f4c32299c3738739b678ece62752a7b new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/case when then null else 1Y end -0-7f4c32299c3738739b678ece62752a7b @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc b/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 b/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 b/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff b/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d b/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d new file mode 100644 index 0000000000000..0cfbf08886fca --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d @@ -0,0 +1 @@ +2 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 b/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e b/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 b/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 new file mode 100644 index 0000000000000..0cfbf08886fca --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 @@ -0,0 +1 @@ +2 diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e b/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e new file mode 100644 index 0000000000000..0cfbf08886fca --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e @@ -0,0 +1 @@ +2 diff --git a/sql/hive/src/test/resources/golden/cross_join-0-7e4af1870bc73decae43b3383c7d2046 b/sql/hive/src/test/resources/golden/cross_join-0-7e4af1870bc73decae43b3383c7d2046 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/cross_join-1-1a96761bf3e47ace9a422ed58273ff35 b/sql/hive/src/test/resources/golden/cross_join-1-1a96761bf3e47ace9a422ed58273ff35 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/cross_join-2-85c93a81eae05bf56a04a904bb80a229 b/sql/hive/src/test/resources/golden/cross_join-2-85c93a81eae05bf56a04a904bb80a229 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-0-d84a430d0ab7a63a0a73361f8d188a4b b/sql/hive/src/test/resources/golden/show_columns-0-d84a430d0ab7a63a0a73361f8d188a4b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-1-ac73cff018bf94944244117a2eb76f7f b/sql/hive/src/test/resources/golden/show_columns-1-ac73cff018bf94944244117a2eb76f7f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-10-695a68c82308540eba1d0a04e032cf39 b/sql/hive/src/test/resources/golden/show_columns-10-695a68c82308540eba1d0a04e032cf39 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-11-691b4e6664e6d435233ea4e8c3b585d5 b/sql/hive/src/test/resources/golden/show_columns-11-691b4e6664e6d435233ea4e8c3b585d5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-12-afc350d459a8f794cc3ca93092163a0c b/sql/hive/src/test/resources/golden/show_columns-12-afc350d459a8f794cc3ca93092163a0c new file mode 100644 index 0000000000000..1711d0c2bb253 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_columns-12-afc350d459a8f794cc3ca93092163a0c @@ -0,0 +1 @@ +a diff --git a/sql/hive/src/test/resources/golden/show_columns-13-e86d559aeb84a4cc017a103182c22bfb b/sql/hive/src/test/resources/golden/show_columns-13-e86d559aeb84a4cc017a103182c22bfb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-14-7c1d63fa270b4d94b69ad49c3e2378a6 b/sql/hive/src/test/resources/golden/show_columns-14-7c1d63fa270b4d94b69ad49c3e2378a6 new file mode 100644 index 0000000000000..1711d0c2bb253 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_columns-14-7c1d63fa270b4d94b69ad49c3e2378a6 @@ -0,0 +1 @@ +a diff --git a/sql/hive/src/test/resources/golden/show_columns-15-2c404655e35f7bd7ce54500c832f9d8e b/sql/hive/src/test/resources/golden/show_columns-15-2c404655e35f7bd7ce54500c832f9d8e new file mode 100644 index 0000000000000..1711d0c2bb253 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_columns-15-2c404655e35f7bd7ce54500c832f9d8e @@ -0,0 +1 @@ +a diff --git a/sql/hive/src/test/resources/golden/show_columns-2-b74990316ec4245fd8a7011e684b39da b/sql/hive/src/test/resources/golden/show_columns-2-b74990316ec4245fd8a7011e684b39da new file mode 100644 index 0000000000000..70c14c3ef34ab --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_columns-2-b74990316ec4245fd8a7011e684b39da @@ -0,0 +1,3 @@ +key +value +ds diff --git a/sql/hive/src/test/resources/golden/show_columns-3-6e40309b0ca10f353a68395ccd64d566 b/sql/hive/src/test/resources/golden/show_columns-3-6e40309b0ca10f353a68395ccd64d566 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-4-a62fc229d241303bffb29b34ad125f8c b/sql/hive/src/test/resources/golden/show_columns-4-a62fc229d241303bffb29b34ad125f8c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-5-691b4e6664e6d435233ea4e8c3b585d5 b/sql/hive/src/test/resources/golden/show_columns-5-691b4e6664e6d435233ea4e8c3b585d5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-6-37c88438bd364343a50f64cf39bfcaf6 b/sql/hive/src/test/resources/golden/show_columns-6-37c88438bd364343a50f64cf39bfcaf6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_columns-7-afc350d459a8f794cc3ca93092163a0c b/sql/hive/src/test/resources/golden/show_columns-7-afc350d459a8f794cc3ca93092163a0c new file mode 100644 index 0000000000000..1711d0c2bb253 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_columns-7-afc350d459a8f794cc3ca93092163a0c @@ -0,0 +1 @@ +a diff --git a/sql/hive/src/test/resources/golden/show_columns-8-9b0b96593ca513c6932f3ed8df68808a b/sql/hive/src/test/resources/golden/show_columns-8-9b0b96593ca513c6932f3ed8df68808a new file mode 100644 index 0000000000000..1711d0c2bb253 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_columns-8-9b0b96593ca513c6932f3ed8df68808a @@ -0,0 +1 @@ +a diff --git a/sql/hive/src/test/resources/golden/show_columns-9-6c0fa8be1c19d4d216dfe7427df1275f b/sql/hive/src/test/resources/golden/show_columns-9-6c0fa8be1c19d4d216dfe7427df1275f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-0-ca75bef7d151a44b6a89dd92333eb12a b/sql/hive/src/test/resources/golden/show_tblproperties-0-ca75bef7d151a44b6a89dd92333eb12a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae new file mode 100644 index 0000000000000..0f6cc6f44f1f7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae @@ -0,0 +1 @@ +Table tmpfoo does not have property: bar diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-2-7c7993eea8e41cf095afae07772cc16e b/sql/hive/src/test/resources/golden/show_tblproperties-2-7c7993eea8e41cf095afae07772cc16e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-3-2b4b8c43ef08bdb418405917d475ac1d b/sql/hive/src/test/resources/golden/show_tblproperties-3-2b4b8c43ef08bdb418405917d475ac1d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-4-6c63215ea599f6533666c4d70606b139 b/sql/hive/src/test/resources/golden/show_tblproperties-4-6c63215ea599f6533666c4d70606b139 new file mode 100644 index 0000000000000..ce1a3441a1bc0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_tblproperties-4-6c63215ea599f6533666c4d70606b139 @@ -0,0 +1,6 @@ + +last_modified_by ocquery +last_modified_time 1408598216 +tmp true +transient_lastDdlTime 1408598216 +bar bar value diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-5-be4adb893c7f946ebd76a648ce3cc1ae b/sql/hive/src/test/resources/golden/show_tblproperties-5-be4adb893c7f946ebd76a648ce3cc1ae new file mode 100644 index 0000000000000..37214958dafe5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_tblproperties-5-be4adb893c7f946ebd76a648ce3cc1ae @@ -0,0 +1 @@ +bar value diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-6-9dd8d67460f558955d96a107ca996ad b/sql/hive/src/test/resources/golden/show_tblproperties-6-9dd8d67460f558955d96a107ca996ad new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 7c82964b5ecdc..a35c40efdc207 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.scalatest.BeforeAndAfterAll + import scala.reflect.ClassTag @@ -26,7 +28,9 @@ import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -class StatisticsSuite extends QueryTest { +class StatisticsSuite extends QueryTest with BeforeAndAfterAll { + TestHive.reset() + TestHive.cacheTables = false test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -126,7 +130,7 @@ class StatisticsSuite extends QueryTest { val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1) + assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -146,7 +150,8 @@ class StatisticsSuite extends QueryTest { val sizes = rdd.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold + && sizes(1) <= autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 502ce8fb297e9..671c3b162f875 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -195,6 +195,9 @@ abstract class HiveComparisonTest val installHooksCommand = "(?i)SET.*hooks".r def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) { + // testCaseName must not contain ':', which is not allowed to appear in a filename of Windows + assert(!testCaseName.contains(":")) + // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala new file mode 100644 index 0000000000000..4ed58f4be1167 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.Row + +/** + * A set of tests that validates support for Hive Explain command. + */ +class HiveExplainSuite extends QueryTest { + private def check(sqlCmd: String, exists: Boolean, keywords: String*) { + val outputs = sql(sqlCmd).collect().map(_.getString(0)).mkString + for (key <- keywords) { + if (exists) { + assert(outputs.contains(key), s"Failed for $sqlCmd ($key doens't exist in result)") + } else { + assert(!outputs.contains(key), s"Failed for $sqlCmd ($key existed in the result)") + } + } + } + + test("explain extended command") { + check(" explain select * from src where key=123 ", true, + "== Physical Plan ==") + check(" explain select * from src where key=123 ", false, + "== Parsed Logical Plan ==", + "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==") + check(" explain extended select * from src where key=123 ", true, + "== Parsed Logical Plan ==", + "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==", + "== Physical Plan ==", + "Code Generation", "== RDD ==") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fdb2f41f5a5b6..6bf8d18a5c32c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import scala.util.Try -import org.apache.spark.sql.{SchemaRDD, Row} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -32,6 +31,71 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("count distinct 0 values", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 'a' AS a FROM src LIMIT 0) table + """.stripMargin) + + createQueryTest("count distinct 1 value strings", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL + | SELECT 'b' AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 1 AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 2 values", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 2 AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 2 values including null", + """ + |SELECT COUNT(DISTINCT a, 1) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT null AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value + null", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT null AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value long", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT 1L AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 2 values long", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT 2L AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value + null long", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT null AS a FROM src LIMIT 1) table + """.stripMargin) + createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -244,11 +308,11 @@ class HiveQuerySuite extends HiveComparisonTest { } } - createQueryTest("case sensitivity: Hive table", + createQueryTest("case sensitivity when query Hive table", "SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15") test("case sensitivity: registered table") { - val testData: SchemaRDD = + val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) @@ -262,7 +326,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.size > 1 && explanation.head.startsWith("Physical execution plan") + explanation.exists(_ == "== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -402,7 +466,7 @@ class HiveQuerySuite extends HiveComparisonTest { } // Describe a registered temporary table. - val testData: SchemaRDD = + val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) @@ -430,6 +494,45 @@ class HiveQuerySuite extends HiveComparisonTest { } } + test("ADD JAR command") { + val testJar = TestHive.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath + sql("CREATE TABLE alter1(a INT, b INT)") + intercept[Exception] { + sql( + """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' + |WITH serdeproperties('s1'='9') + """.stripMargin) + } + sql(s"ADD JAR $testJar") + sql( + """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' + |WITH serdeproperties('s1'='9') + """.stripMargin) + sql("DROP TABLE alter1") + } + + case class LogEntry(filename: String, message: String) + case class LogFile(name: String) + + test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { + sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") + sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") + + sql( + """ + SELECT name, message + FROM rawLogs + JOIN ( + SELECT name + FROM logFiles + ) files + ON rawLogs.filename = files.name + """).registerTempTable("boom") + + // This should be successfully analyzed + sql("SELECT * FROM boom").queryExecution.analyzed + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -455,62 +558,67 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - + val KV = "([^=]+)=([^=]*)".r + def collectResults(rdd: SchemaRDD): Set[(String, String)] = + rdd.collect().map { + case Row(key: String, value: String) => key -> value + case Row(KV(key, value)) => key -> value + }.toSet clear() // "set" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... assert(sql("SET").collect().size == 0) - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal)) { + collectResults(hql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal)) { + collectResults(hql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { - sql(s"SET").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + collectResults(hql("SET")) } // "set key" - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal)) { + collectResults(hql(s"SET $testKey")) } - assertResult(Array(s"$nonexistentKey=")) { - sql(s"SET $nonexistentKey").collect().map(_.getString(0)) + assertResult(Set(nonexistentKey -> "")) { + collectResults(hql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as sql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal)) { + collectResults(sql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Array(s"$testKey=$testVal")) { - sql("SET").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal)) { + collectResults(sql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { - sql("SET").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + collectResults(sql("SET")) } - assertResult(Array(s"$testKey=$testVal")) { - sql(s"SET $testKey").collect().map(_.getString(0)) + assertResult(Set(testKey -> testVal)) { + collectResults(sql(s"SET $testKey")) } - assertResult(Array(s"$nonexistentKey=")) { - sql(s"SET $nonexistentKey").collect().map(_.getString(0)) + assertResult(Set(nonexistentKey -> "")) { + collectResults(sql(s"SET $nonexistentKey")) } clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 6b3ffd1c0ffe2..b6be6bc1bfefe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) case class Nested(a: Int, B: Int) +case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. @@ -57,6 +57,13 @@ class HiveResolutionSuite extends HiveComparisonTest { .registerTempTable("caseSensitivityTest") sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + + println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution) + + sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect() + + // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a") + } test("nested repeated resolution") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index df9bae96494d5..7486bfa82b00b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,10 +17,19 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive + /** * A set of tests that validates support for Hive SerDe. */ -class HiveSerDeSuite extends HiveComparisonTest { +class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { + + override def beforeAll() = { + TestHive.cacheTables = false + } + createQueryTest( "Read and write with LazySimpleSerDe (tab separated)", "SELECT * from serdeins") @@ -28,4 +37,6 @@ class HiveSerDeSuite extends HiveComparisonTest { createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") + + createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index c3c18cf8ccac3..48fffe53cf2ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -33,6 +33,12 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } } + val nullVal = "null" + baseTypes.init.foreach { i => + createQueryTest(s"case when then $i else $nullVal end ", s"SELECT case when true then $i else $nullVal end FROM src limit 1") + createQueryTest(s"case when then $nullVal else $i end ", s"SELECT case when true then $nullVal else $i end FROM src limit 1") + } + test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 1a6dbc0ce0c0d..8275e2d3bcce3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.hive.test.TestHive /* Implicit conversions */ @@ -25,9 +27,10 @@ import scala.collection.JavaConversions._ /** * A set of test cases that validate partition and column pruning. */ -class PruningSuite extends HiveComparisonTest { +class PruningSuite extends HiveComparisonTest with BeforeAndAfter { // MINOR HACK: You must run a query before calling reset the first time. TestHive.sql("SHOW TABLES") + TestHive.cacheTables = false // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset // the environment to ensure all referenced tables in this suites are not cached in-memory. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 635a9fb0d56cb..679efe082f2a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.hive.execution -import scala.reflect.ClassTag +import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.Row import org.apache.spark.sql.hive.test.TestHive._ +case class Nested1(f1: Nested2) +case class Nested2(f2: Nested3) +case class Nested3(f3: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is @@ -47,4 +49,18 @@ class SQLQuerySuite extends QueryTest { GROUP BY key, value ORDER BY value) a""").collect().toSeq) } + + test("double nested data") { + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") + checkAnswer( + sql("SELECT f1.f2.f3 FROM nested"), + 1) + } + + test("test CTAS") { + checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + checkAnswer( + sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) + } } diff --git a/streaming/pom.xml b/streaming/pom.xml index ce35520a28609..12f900c91eb98 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 101cec1c7a7c2..f63560dcb5b89 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -37,7 +37,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver} import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.ui.StreamingTab +import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.MetadataCleaner /** @@ -158,7 +158,14 @@ class StreamingContext private[streaming] ( private[streaming] val waiter = new ContextWaiter - private[streaming] val uiTab = new StreamingTab(this) + private[streaming] val progressListener = new StreamingJobProgressListener(this) + + private[streaming] val uiTab: Option[StreamingTab] = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(new StreamingTab(this)) + } else { + None + } /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) @@ -240,7 +247,7 @@ class StreamingContext private[streaming] ( * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param props Props object defining creation of the actor * @param name Name of the actor - * @param storageLevel RDD storage level. Defaults to memory-only. + * @param storageLevel RDD storage level (default: StorageLevel.MEMORY_AND_DISK_SER_2) * * @note An important point to note: * Since Actor may exist outside the spark framework, It is thus user's responsibility diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 75f0e8716dc7e..e35a568ddf115 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -26,7 +26,7 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { override val metricRegistry = new MetricRegistry override val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) - private val streamingListener = ssc.uiTab.listener + private val streamingListener = ssc.progressListener private def registerGauge[T](name: String, f: StreamingJobProgressListener => T, defaultValue: T) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index d934b9cbfc3e8..53a3e6200e340 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -20,22 +20,21 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.ArrayBuffer import scala.concurrent.Await import akka.actor.{Actor, Props} import akka.pattern.ask +import com.google.common.base.Throwables + import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.scheduler.DeregisterReceiver import org.apache.spark.streaming.scheduler.AddBlock -import scala.Some import org.apache.spark.streaming.scheduler.RegisterReceiver -import com.google.common.base.Throwables /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -56,7 +55,8 @@ private[streaming] class ReceiverSupervisorImpl( private val trackerActor = { val ip = env.conf.get("spark.driver.host", "localhost") val port = env.conf.getInt("spark.driver.port", 7077) - val url = "akka.tcp://spark@%s:%s/user/ReceiverTracker".format(ip, port) + val url = "akka.tcp://%s@%s:%s/user/ReceiverTracker".format( + SparkEnv.driverActorSystemName, ip, port) env.actorSystem.actorSelection(url) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 34ac254f337eb..d9d04cd706a04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,18 +17,31 @@ package org.apache.spark.streaming.ui -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.SparkUITab +import org.apache.spark.ui.{SparkUI, SparkUITab} -/** Spark Web UI tab that shows statistics of a streaming job */ +import StreamingTab._ + +/** + * Spark Web UI tab that shows statistics of a streaming job. + * This assumes the given SparkContext has enabled its SparkUI. + */ private[spark] class StreamingTab(ssc: StreamingContext) - extends SparkUITab(ssc.sc.ui, "streaming") with Logging { + extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { - val parent = ssc.sc.ui - val listener = new StreamingJobProgressListener(ssc) + val parent = getSparkUI(ssc) + val listener = ssc.progressListener ssc.addStreamingListener(listener) attachPage(new StreamingPage(this)) parent.attachTab(this) } + +private object StreamingTab { + def getSparkUI(ssc: StreamingContext): SparkUI = { + ssc.sc.ui.getOrElse { + throw new SparkException("Parent SparkUI to attach this tab to not found!") + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index ff6d86c8f81ac..059ac6c2dbee2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -452,7 +452,7 @@ class BasicOperationsSuite extends TestSuiteBase { test("rdd cleanup - updateStateByKey") { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some(values.foldLeft(0)(_ + _) + state.getOrElse(0)) + Some(values.sum + state.getOrElse(0)) } val stateStream = runCleanupTest( conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3))) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 10ad3c9e1adc9..8511390cb1ad5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -70,7 +70,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.foldLeft(0)(_ + _) + state.getOrElse(0))) + Some((values.sum + state.getOrElse(0))) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -214,7 +214,7 @@ class CheckpointSuite extends TestSuiteBase { val output = (1 to 10).map(x => Seq(("a", x))).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.foldLeft(0)(_ + _) + state.getOrElse(0))) + Some((values.sum + state.getOrElse(0))) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index f4e11f975de94..99c8d13231aac 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.language.postfixOps import org.apache.spark.SparkConf import org.apache.spark.storage.{StorageLevel, StreamBlockId} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 7b33d3b235466..a3cabd6be02fe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -29,8 +29,6 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import scala.language.postfixOps - class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 2861f5335ae36..84fed95a75e67 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming import scala.collection.mutable.ArrayBuffer import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global -import scala.language.postfixOps import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala index 2a0db7564915d..8e30118266855 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala @@ -18,19 +18,27 @@ package org.apache.spark.streaming import scala.io.Source -import scala.language.postfixOps import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf + class UISuite extends FunSuite { // Ignored: See SPARK-1530 ignore("streaming tab in spark UI") { - val ssc = new StreamingContext("local", "test", Seconds(1)) + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val ssc = new StreamingContext(conf, Seconds(1)) + assert(ssc.sc.ui.isDefined, "Spark UI is not started!") + val ui = ssc.sc.ui.get + eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ssc.sparkContext.ui.appUIAddress).mkString + val html = Source.fromURL(ui.appUIAddress).mkString assert(!html.contains("random data that should not be present")) // test if streaming tab exist assert(html.toLowerCase.contains("streaming")) @@ -39,8 +47,7 @@ class UISuite extends FunSuite { } eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL( - ssc.sparkContext.ui.appUIAddress.stripSuffix("/") + "/streaming").mkString + val html = Source.fromURL(ui.appUIAddress.stripSuffix("/") + "/streaming").mkString assert(html.toLowerCase.contains("batch")) assert(html.toLowerCase.contains("network")) } diff --git a/tools/pom.xml b/tools/pom.xml index 97abb6b2b63e0..f36674476770c 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 17bf7c2541d13..db58eb642b56d 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -20,10 +20,11 @@ package org.apache.spark.tools import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.SparkContext import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util.Utils -import org.apache.spark.executor.ShuffleWriteMetrics /** * Internal utility for micro-benchmarking shuffle write performance. @@ -50,13 +51,15 @@ object StoragePerfTester { System.setProperty("spark.shuffle.compress", "false") System.setProperty("spark.shuffle.sync", "true") + System.setProperty("spark.shuffle.manager", + "org.apache.spark.shuffle.hash.HashShuffleManager") // This is only used to instantiate a BlockManager. All thread scheduling is done manually. val sc = new SparkContext("local[4]", "Write Tester") - val blockManager = sc.env.blockManager + val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { diff --git a/tox.ini b/tox.ini index a1fefdd0e176f..b568029a204cc 100644 --- a/tox.ini +++ b/tox.ini @@ -15,4 +15,4 @@ [pep8] max-line-length=100 -exclude=cloudpickle.py +exclude=cloudpickle.py,heapq3.py diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 51744ece0412d..7dadbba58fd82 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala deleted file mode 100644 index 4d4848b1bd8f8..0000000000000 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ /dev/null @@ -1,453 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn - -import java.io.IOException -import java.net.Socket -import java.util.concurrent.CopyOnWriteArrayList -import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.util.ShutdownHookManager -import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.api.protocolrecords._ -import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.{SignalLogger, Utils} - -/** - * An application master that runs the users driver program and allocates executors. - */ -class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, - sparkConf: SparkConf) extends Logging { - - def this(args: ApplicationMasterArguments, sparkConf: SparkConf) = - this(args, new Configuration(), sparkConf) - - def this(args: ApplicationMasterArguments) = this(args, new SparkConf()) - - private val rpc: YarnRPC = YarnRPC.create(conf) - private var resourceManager: AMRMProtocol = _ - private var appAttemptId: ApplicationAttemptId = _ - private var userThread: Thread = _ - private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - private val fs = FileSystem.get(yarnConf) - - private var yarnAllocator: YarnAllocationHandler = _ - private var isFinished: Boolean = false - private var uiAddress: String = _ - private var uiHistoryAddress: String = _ - private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, - YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) - private var isLastAMRetry: Boolean = true - - // Default to numExecutors * 2, with minimum of 3 - private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) - - private var registered = false - - def run() { - // set the web ui port to be ephemeral for yarn so we don't conflict with - // other spark processes running on the same box - System.setProperty("spark.ui.port", "0") - - // when running the AM, the Spark master is always "yarn-cluster" - System.setProperty("spark.master", "yarn-cluster") - - // Use priority 30 as its higher then HDFS. Its same priority as MapReduce is using. - ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) - - appAttemptId = getApplicationAttemptId() - isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts - resourceManager = registerWithResourceManager() - - // setup AmIpFilter for the SparkUI - do this before we start the UI - addAmIpFilter() - - ApplicationMaster.register(this) - - // Call this to force generation of secret so it gets populated into the - // hadoop UGI. This has to happen before the startUserClass which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) - - // Start the user's JAR - userThread = startUserClass() - - // This a bit hacky, but we need to wait until the spark.driver.port property has - // been set by the Thread executing the user class. - waitForSparkContextInitialized() - - // Do this after spark master is up and SparkContext is created so that we can register UI Url - synchronized { - if (!isFinished) { - registerApplicationMaster() - registered = true - } - } - - // Allocate all containers - allocateExecutors() - - // Wait for the user class to Finish - userThread.join() - - System.exit(0) - } - - // add the yarn amIpFilter that Yarn requires for properly securing the UI - private def addAmIpFilter() { - val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" - System.setProperty("spark.ui.filters", amFilter) - val proxy = YarnConfiguration.getProxyHostAndPort(conf) - val parts : Array[String] = proxy.split(":") - val uriBase = "http://" + proxy + - System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) - - val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase - System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", - params) - } - - private def getApplicationAttemptId(): ApplicationAttemptId = { - val envs = System.getenv() - val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) - val containerId = ConverterUtils.toContainerId(containerIdString) - val appAttemptId = containerId.getApplicationAttemptId() - logInfo("ApplicationAttemptId: " + appAttemptId) - appAttemptId - } - - private def registerWithResourceManager(): AMRMProtocol = { - val rmAddress = NetUtils.createSocketAddr(yarnConf.get( - YarnConfiguration.RM_SCHEDULER_ADDRESS, - YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) - logInfo("Connecting to ResourceManager at " + rmAddress) - rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] - } - - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - logInfo("Registering the ApplicationMaster") - val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) - .asInstanceOf[RegisterApplicationMasterRequest] - appMasterRequest.setApplicationAttemptId(appAttemptId) - // Setting this to master host,port - so that the ApplicationReport at client has some - // sensible info. - // Users can then monitor stderr/stdout on that node if required. - appMasterRequest.setHost(Utils.localHostName()) - appMasterRequest.setRpcPort(0) - appMasterRequest.setTrackingUrl(uiAddress) - resourceManager.registerApplicationMaster(appMasterRequest) - } - - private def startUserClass(): Thread = { - logInfo("Starting the user JAR in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) - val mainMethod = Class.forName( - args.userClass, - false /* initialize */ , - Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) - val t = new Thread { - override def run() { - - var successed = false - try { - // Copy - var mainArgs: Array[String] = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) - // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR - // userThread will stop here unless it has uncaught exception thrown out - // It need shutdown hook to set SUCCEEDED - successed = true - } finally { - logDebug("finishing main") - isLastAMRetry = true - if (successed) { - ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) - } else { - ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED) - } - } - } - } - t.start() - t - } - - // this need to happen before allocateExecutors - private def waitForSparkContextInitialized() { - logInfo("Waiting for spark context initialization") - try { - var sparkContext: SparkContext = null - ApplicationMaster.sparkContextRef.synchronized { - var count = 0 - val waitTime = 10000L - val numTries = sparkConf.getInt("spark.yarn.ApplicationMaster.waitTries", 10) - while (ApplicationMaster.sparkContextRef.get() == null && count < numTries - && !isFinished) { - logInfo("Waiting for spark context initialization ... " + count) - count = count + 1 - ApplicationMaster.sparkContextRef.wait(waitTime) - } - sparkContext = ApplicationMaster.sparkContextRef.get() - assert(sparkContext != null || count >= numTries) - - if (null != sparkContext) { - uiAddress = sparkContext.ui.appUIHostPort - uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) - this.yarnAllocator = YarnAllocationHandler.newAllocator( - yarnConf, - resourceManager, - appAttemptId, - args, - sparkContext.preferredNodeLocationData, - sparkContext.getConf) - } else { - logWarning("Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d". - format(count * waitTime, numTries)) - this.yarnAllocator = YarnAllocationHandler.newAllocator( - yarnConf, - resourceManager, - appAttemptId, - args, - sparkContext.getConf) - } - } - } - } - - private def allocateExecutors() { - try { - logInfo("Allocating " + args.numExecutors + " executors.") - // Wait until all containers have finished - // TODO: This is a bit ugly. Can we make it nicer? - // TODO: Handle container failure - - // Exits the loop if the user thread exits. - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive - && !isFinished) { - checkNumExecutorsFailed() - yarnAllocator.allocateContainers( - math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) - Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) - } - } - logInfo("All executors have launched.") - - // Launch a progress reporter thread, else the app will get killed after expiration - // (def: 10mins) timeout. - // TODO(harvey): Verify the timeout - if (userThread.isAlive) { - // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. - val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - - // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) - - // must be <= timeoutInterval / 2. - val interval = math.min(timeoutInterval / 2, schedulerInterval) - - launchReporterThread(interval) - } - } - - private def launchReporterThread(_sleepTime: Long): Thread = { - val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime - - val t = new Thread { - override def run() { - while (userThread.isAlive && !isFinished) { - checkNumExecutorsFailed() - val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - if (missingExecutorCount > 0) { - logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingExecutorCount)) - yarnAllocator.allocateContainers(missingExecutorCount) - } else { - sendProgress() - } - Thread.sleep(sleepTime) - } - } - } - // Setting to daemon status, though this is usually not a good idea. - t.setDaemon(true) - t.start() - logInfo("Started progress reporter thread - sleep time : " + sleepTime) - t - } - - private def checkNumExecutorsFailed() { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - logInfo("max number of executor failures reached") - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - // make sure to stop the user thread - val sparkContext = ApplicationMaster.sparkContextRef.get() - if (sparkContext != null) { - logInfo("Invoking sc stop from checkNumExecutorsFailed") - sparkContext.stop() - } else { - logError("sparkContext is null when should shutdown") - } - } - } - - private def sendProgress() { - logDebug("Sending progress") - // Simulated with an allocate request with no nodes requested ... - yarnAllocator.allocateContainers(0) - } - - /* - def printContainers(containers: List[Container]) = { - for (container <- containers) { - logInfo("Launching shell command on a new container." - + ", containerId=" + container.getId() - + ", containerNode=" + container.getNodeId().getHost() - + ":" + container.getNodeId().getPort() - + ", containerNodeURI=" + container.getNodeHttpAddress() - + ", containerState" + container.getState() - + ", containerResourceMemory" - + container.getResource().getMemory()) - } - } - */ - - def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") { - synchronized { - if (isFinished) { - return - } - isFinished = true - - logInfo("finishApplicationMaster with " + status) - if (registered) { - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(appAttemptId) - finishReq.setFinishApplicationStatus(status) - finishReq.setDiagnostics(diagnostics) - finishReq.setTrackingUrl(uiHistoryAddress) - resourceManager.finishApplicationMaster(finishReq) - } - } - } - - /** - * Clean up the staging directory. - */ - private def cleanupStagingDir() { - var stagingDirPath: Path = null - try { - val preserveFiles = sparkConf.get("spark.yarn.preserve.staging.files", "false").toBoolean - if (!preserveFiles) { - stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) - if (stagingDirPath == null) { - logError("Staging directory is null") - return - } - logInfo("Deleting staging directory " + stagingDirPath) - fs.delete(stagingDirPath, true) - } - } catch { - case ioe: IOException => - logError("Failed to cleanup staging dir " + stagingDirPath, ioe) - } - } - - // The shutdown hook that runs when a signal is received AND during normal close of the JVM. - class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable { - - def run() { - logInfo("AppMaster received a signal.") - // we need to clean up staging dir before HDFS is shut down - // make sure we don't delete it until this is the last AM - if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir() - } - } - -} - -object ApplicationMaster extends Logging { - // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be - // optimal as more containers are available. Might need to handle this better. - private val ALLOCATE_HEARTBEAT_INTERVAL = 100 - - private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() - - def register(master: ApplicationMaster) { - applicationMasters.add(master) - } - - val sparkContextRef: AtomicReference[SparkContext] = - new AtomicReference[SparkContext](null /* initialValue */) - - def sparkContextInitialized(sc: SparkContext): Boolean = { - var modified = false - sparkContextRef.synchronized { - modified = sparkContextRef.compareAndSet(null, sc) - sparkContextRef.notifyAll() - } - - // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do - // System.exit. - // Should not really have to do this, but it helps YARN to evict resources earlier. - // Not to mention, prevent the Client from declaring failure even though we exited properly. - // Note that this will unfortunately not properly clean up the staging files because it gets - // called too late, after the filesystem is already shutdown. - if (modified) { - Runtime.getRuntime().addShutdownHook(new Thread with Logging { - // This is not only logs, but also ensures that log system is initialized for this instance - // when we are actually 'run'-ing. - logInfo("Adding shutdown hook for context " + sc) - - override def run() { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - // Best case ... - for (master <- applicationMasters) { - master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) - } - } - }) - } - - modified - } - - def main(argStrings: Array[String]) { - SignalLogger.register(log) - val args = new ApplicationMasterArguments(argStrings) - SparkHadoopUtil.get.runAsSparkUser { () => - new ApplicationMaster(args).run() - } - } -} diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 82f79d88a3009..aff9ab71f0937 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, Records} import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil /** * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's alpha API. @@ -40,7 +41,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa extends YarnClientImpl with ClientBase with Logging { def this(clientArgs: ClientArguments, spConf: SparkConf) = - this(clientArgs, new Configuration(), spConf) + this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) @@ -89,17 +90,8 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def logClusterResourceDetails() { val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics - logInfo("Got Cluster metric info from ASM, numNodeManagers = " + + logInfo("Got cluster metric info from ASM, numNodeManagers = " + clusterMetrics.getNumNodeManagers) - - val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) - logInfo( """Queue info ... queueName = %s, queueCurrentCapacity = %s, queueMaxCapacity = %s, - queueApplicationCount = %s, queueChildQueueCount = %s""".format( - queueInfo.getQueueName, - queueInfo.getCurrentCapacity, - queueInfo.getMaximumCapacity, - queueInfo.getApplications.size, - queueInfo.getChildQueues.size)) } @@ -111,14 +103,6 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa appContext } - def calculateAMMemory(newApp: GetNewApplicationResponse): Int = { - val minResMemory = newApp.getMinimumResourceCapability().getMemory() - val amMemory = ((args.amMemory / minResMemory) * minResMemory) + - ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - memoryOverhead) - amMemory - } - def setupSecurityToken(amContainer: ContainerLaunchContext) = { // Setup security tokens. val dob = new DataOutputBuffer() diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala deleted file mode 100644 index c3310fbc24a98..0000000000000 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn - -import java.net.Socket -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.api.protocolrecords._ -import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import akka.actor._ -import akka.remote._ -import org.apache.spark.{Logging, SecurityManager, SparkConf} -import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter -import org.apache.spark.scheduler.SplitInfo -import org.apache.spark.deploy.SparkHadoopUtil - -/** - * An application master that allocates executors on behalf of a driver that is running outside - * the cluster. - * - * This is used only in yarn-client mode. - */ -class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sparkConf: SparkConf) - extends Logging { - - def this(args: ApplicationMasterArguments, sparkConf: SparkConf) = - this(args, new Configuration(), sparkConf) - - def this(args: ApplicationMasterArguments) = this(args, new SparkConf()) - - private val rpc: YarnRPC = YarnRPC.create(conf) - private var resourceManager: AMRMProtocol = _ - private var appAttemptId: ApplicationAttemptId = _ - private var reporterThread: Thread = _ - private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - - private var yarnAllocator: YarnAllocationHandler = _ - - private var driverClosed: Boolean = false - private var isFinished: Boolean = false - private var registered: Boolean = false - - // Default to numExecutors * 2, with minimum of 3 - private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) - - val securityManager = new SecurityManager(sparkConf) - val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityManager)._1 - var actor: ActorRef = _ - - // This actor just working as a monitor to watch on Driver Actor. - class MonitorActor(driverUrl: String) extends Actor { - - var driver: ActorSelection = _ - - override def preStart() { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message thus the connection is actually established, thus we can - // monitor Lifecycle Events. - driver ! "Hello" - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - - override def receive = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - driverClosed = true - case x: AddWebUIFilter => - logInfo(s"Add WebUI Filter. $x") - driver ! x - } - } - - def run() { - appAttemptId = getApplicationAttemptId() - resourceManager = registerWithResourceManager() - - synchronized { - if (!isFinished) { - val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() - // Compute number of threads for akka - val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - - if (minimumMemory > 0) { - val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnAllocationHandler.MEMORY_OVERHEAD) - val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. - // args.workerCores = numCore - } - } - registered = true - } - } - waitForSparkMaster() - addAmIpFilter() - // Allocate all containers - allocateExecutors() - - // Launch a progress reporter thread, else app will get killed after expiration - // (def: 10mins) timeout ensure that progress is sent before - // YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. - - val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong - - // must be <= timeoutInterval / 2. - val interval = math.min(timeoutInterval / 2, schedulerInterval) - - reporterThread = launchReporterThread(interval) - - // Wait for the reporter thread to Finish. - reporterThread.join() - - finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) - actorSystem.shutdown() - - logInfo("Exited") - System.exit(0) - } - - private def getApplicationAttemptId(): ApplicationAttemptId = { - val envs = System.getenv() - val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) - val containerId = ConverterUtils.toContainerId(containerIdString) - val appAttemptId = containerId.getApplicationAttemptId() - logInfo("ApplicationAttemptId: " + appAttemptId) - appAttemptId - } - - private def registerWithResourceManager(): AMRMProtocol = { - val rmAddress = NetUtils.createSocketAddr(yarnConf.get( - YarnConfiguration.RM_SCHEDULER_ADDRESS, - YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) - logInfo("Connecting to ResourceManager at " + rmAddress) - rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] - } - - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") - logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") - val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) - .asInstanceOf[RegisterApplicationMasterRequest] - appMasterRequest.setApplicationAttemptId(appAttemptId) - // Setting this to master host,port - so that the ApplicationReport at client has - // some sensible info. Users can then monitor stderr/stdout on that node if required. - appMasterRequest.setHost(Utils.localHostName()) - appMasterRequest.setRpcPort(0) - // What do we provide here ? Might make sense to expose something sensible later ? - appMasterRequest.setTrackingUrl(appUIAddress) - resourceManager.registerApplicationMaster(appMasterRequest) - } - - // add the yarn amIpFilter that Yarn requires for properly securing the UI - private def addAmIpFilter() { - val proxy = YarnConfiguration.getProxyHostAndPort(conf) - val parts = proxy.split(":") - val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) - val uriBase = "http://" + proxy + proxyBase - val amFilter = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase - val amFilterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" - actor ! AddWebUIFilter(amFilterName, amFilter, proxyBase) - } - - private def waitForSparkMaster() { - logInfo("Waiting for spark driver to be reachable.") - var driverUp = false - val hostport = args.userArgs(0) - val (driverHost, driverPort) = Utils.parseHostPort(hostport) - while(!driverUp) { - try { - val socket = new Socket(driverHost, driverPort) - socket.close() - logInfo("Master now available: " + driverHost + ":" + driverPort) - driverUp = true - } catch { - case e: Exception => - logError("Failed to connect to driver at " + driverHost + ":" + driverPort) - Thread.sleep(100) - } - } - sparkConf.set("spark.driver.host", driverHost) - sparkConf.set("spark.driver.port", driverPort.toString) - - val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME) - - actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") - } - - - private def allocateExecutors() { - - // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. - val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = - scala.collection.immutable.Map() - - yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, - args, preferredNodeLocationData, sparkConf) - - logInfo("Allocating " + args.numExecutors + " executors.") - // Wait until all containers have finished - // TODO: This is a bit ugly. Can we make it nicer? - // TODO: Handle container failure - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && - !isFinished) { - yarnAllocator.allocateContainers( - math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) - checkNumExecutorsFailed() - Thread.sleep(100) - } - - logInfo("All executors have launched.") - } - private def checkNumExecutorsFailed() { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } - } - - // TODO: We might want to extend this to allocate more containers in case they die ! - private def launchReporterThread(_sleepTime: Long): Thread = { - val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime - - val t = new Thread { - override def run() { - while (!driverClosed && !isFinished) { - checkNumExecutorsFailed() - val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - if (missingExecutorCount > 0) { - logInfo("Allocating " + missingExecutorCount + - " containers to make up for (potentially ?) lost containers") - yarnAllocator.allocateContainers(missingExecutorCount) - } else { - sendProgress() - } - Thread.sleep(sleepTime) - } - } - } - // setting to daemon status, though this is usually not a good idea. - t.setDaemon(true) - t.start() - logInfo("Started progress reporter thread - sleep time : " + sleepTime) - t - } - - private def sendProgress() { - logDebug("Sending progress") - // simulated with an allocate request with no nodes requested ... - yarnAllocator.allocateContainers(0) - } - - def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { - synchronized { - if (isFinished) { - return - } - logInfo("Unregistering ApplicationMaster with " + status) - if (registered) { - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(appAttemptId) - finishReq.setFinishApplicationStatus(status) - finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) - finishReq.setDiagnostics(appMessage) - resourceManager.finishApplicationMaster(finishReq) - } - isFinished = true - } - } - -} - - -object ExecutorLauncher { - def main(argStrings: Array[String]) { - val args = new ApplicationMasterArguments(argStrings) - SparkHadoopUtil.get.runAsSparkUser { () => - new ExecutorLauncher(args).run() - } - } -} diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7dae248e3e7db..10cbeb8b94325 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SecurityManager, SparkConf, Logging} class ExecutorRunnable( @@ -46,7 +46,8 @@ class ExecutorRunnable( slaveId: String, hostname: String, executorMemory: Int, - executorCores: Int) + executorCores: Int, + securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { var rpc: YarnRPC = YarnRPC.create(conf) @@ -86,6 +87,8 @@ class ExecutorRunnable( logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) + ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // Send the start request to the ContainerManager val startReq = Records.newRecord(classOf[StartContainerRequest]) .asInstanceOf[StartContainerRequest] diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 80e0162e9f277..5a1b42c1e17d5 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -17,397 +17,47 @@ package org.apache.spark.deploy.yarn -import java.lang.{Boolean => JBoolean} -import java.util.{Collections, Set => JSet} -import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} +import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.atomic.AtomicInteger -import scala.collection import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.scheduler.SplitInfo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.AMRMProtocol -import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId} -import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus} -import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest} -import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} -import org.apache.hadoop.yarn.util.{RackResolver, Records} - - -object AllocationType extends Enumeration { - type AllocationType = Value - val HOST, RACK, ANY = Value -} - -// TODO: -// Too many params. -// Needs to be mt-safe -// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should -// make it more proactive and decoupled. - -// Note that right now, we assume all node asks as uniform in terms of capabilities and priority -// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for -// more info on how we are requesting for containers. +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest +import org.apache.hadoop.yarn.util.Records /** * Acquires resources for executors from a ResourceManager and launches executors in new containers. */ private[yarn] class YarnAllocationHandler( - val conf: Configuration, - val resourceManager: AMRMProtocol, - val appAttemptId: ApplicationAttemptId, - val maxExecutors: Int, - val executorMemory: Int, - val executorCores: Int, - val preferredHostToCount: Map[String, Int], - val preferredRackToCount: Map[String, Int], - val sparkConf: SparkConf) - extends Logging { - // These three are locked on allocatedHostToContainersMap. Complementary data structures - // allocatedHostToContainersMap : containers which are running : host, Set - // allocatedContainerToHostMap: container to host mapping. - private val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]]() - - private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() - - // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an - // allocated node) - // As with the two data structures above, tightly coupled with them, and to be locked on - // allocatedHostToContainersMap - private val allocatedRackCount = new HashMap[String, Int]() - - // Containers which have been released. - private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() - // Containers to be released in next request to RM - private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] - - // Additional memory overhead - in mb. - private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnAllocationHandler.MEMORY_OVERHEAD) - - private val numExecutorsRunning = new AtomicInteger() - // Used to generate a unique id per executor - private val executorIdCounter = new AtomicInteger() - private val lastResponseId = new AtomicInteger() - private val numExecutorsFailed = new AtomicInteger() - - def getNumExecutorsRunning: Int = numExecutorsRunning.intValue - - def getNumExecutorsFailed: Int = numExecutorsFailed.intValue - - def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + memoryOverhead) - } - - def allocateContainers(executorsToRequest: Int) { - // We need to send the request only once from what I understand ... but for now, not modifying - // this much. - - // Keep polling the Resource Manager for containers - val amResp = allocateExecutorResources(executorsToRequest).getAMResponse - - val _allocatedContainers = amResp.getAllocatedContainers() - - if (_allocatedContainers.size > 0) { - logDebug(""" - Allocated containers: %d - Current executor count: %d - Containers released: %s - Containers to be released: %s - Cluster resources: %s - """.format( - _allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers, - amResp.getAvailableResources)) - - val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() - - // Ignore if not satisfying constraints { - for (container <- _allocatedContainers) { - if (isResourceConstraintSatisfied(container)) { - // allocatedContainers += container - - val host = container.getNodeId.getHost - val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]()) - - containers += container - } else { - // Add all ignored containers to released list - releasedContainerList.add(container.getId()) - } - } - - // Find the appropriate containers to use. Slightly non trivial groupBy ... - val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (candidateHost <- hostToContainers.keySet) - { - val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) - val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) - - var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null) - assert(remainingContainers != null) - - if (requiredHostCount >= remainingContainers.size){ - // Since we got <= required containers, add all to dataLocalContainers - dataLocalContainers.put(candidateHost, remainingContainers) - // all consumed - remainingContainers = null - } else if (requiredHostCount > 0) { - // Container list has more containers than we need for data locality. - // Split into two : data local container count of (remainingContainers.size - - // requiredHostCount) and rest as remainingContainer - val (dataLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredHostCount) - dataLocalContainers.put(candidateHost, dataLocal) - // remainingContainers = remaining - - // yarn has nasty habit of allocating a tonne of containers on a host - discourage this : - // add remaining to release list. If we have insufficient containers, next allocation - // cycle will reallocate (but wont treat it as data local) - for (container <- remaining) releasedContainerList.add(container.getId()) - remainingContainers = null - } - - // Now rack local - if (remainingContainers != null){ - val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) - - if (rack != null){ - val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) - val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - - rackLocalContainers.get(rack).getOrElse(List()).size - - - if (requiredRackCount >= remainingContainers.size){ - // Add all to dataLocalContainers - dataLocalContainers.put(rack, remainingContainers) - // All consumed - remainingContainers = null - } else if (requiredRackCount > 0) { - // container list has more containers than we need for data locality. - // Split into two : data local container count of (remainingContainers.size - - // requiredRackCount) and rest as remainingContainer - val (rackLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredRackCount) - val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, - new ArrayBuffer[Container]()) - - existingRackLocal ++= rackLocal - remainingContainers = remaining - } - } - } - - // If still not consumed, then it is off rack host - add to that list. - if (remainingContainers != null){ - offRackContainers.put(candidateHost, remainingContainers) - } - } - - // Now that we have split the containers into various groups, go through them in order : - // first host local, then rack local and then off rack (everything else). - // Note that the list we create below tries to ensure that not all containers end up within a - // host if there are sufficiently large number of hosts/containers. - - val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) - allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) - allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) - allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) - - // Run each of the allocated containers - for (container <- allocatedContainers) { - val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() - val executorHostname = container.getNodeId.getHost - val containerId = container.getId - - assert( container.getResource.getMemory >= - (executorMemory + memoryOverhead)) - - if (numExecutorsRunningNow > maxExecutors) { - logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, executorHostname)) - releasedContainerList.add(containerId) - // reset counter back to old value. - numExecutorsRunning.decrementAndGet() - } else { - // Deallocate + allocate can result in reusing id's wrongly - so use a different counter - // (executorIdCounter) - val executorId = executorIdCounter.incrementAndGet().toString - val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - - logInfo("launching container on " + containerId + " host " + executorHostname) - // Just to be safe, simply remove it from pendingReleaseContainers. - // Should not be there, but .. - pendingReleaseContainers.remove(containerId) - - val rack = YarnAllocationHandler.lookupRack(conf, executorHostname) - allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, - new HashSet[ContainerId]()) - - containerSet += containerId - allocatedContainerToHostMap.put(containerId, executorHostname) - if (rack != null) { - allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) - } - } - - new Thread( - new ExecutorRunnable(container, conf, sparkConf, driverUrl, executorId, - executorHostname, executorMemory, executorCores) - ).start() - } - } - logDebug(""" - Finished processing %d containers. - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } - - - val completedContainers = amResp.getCompletedContainersStatuses() - if (completedContainers.size > 0){ - logDebug("Completed %d containers, to-be-released: %s".format( - completedContainers.size, releasedContainerList)) - for (completedContainer <- completedContainers){ - val containerId = completedContainer.getContainerId - - // Was this released by us ? If yes, then simply remove from containerSet and move on. - if (pendingReleaseContainers.containsKey(containerId)) { - pendingReleaseContainers.remove(containerId) - } else { - // Simply decrement count - next iteration of ReporterThread will take care of allocating. - numExecutorsRunning.decrementAndGet() - logInfo("Completed container %s (state: %s, exit status: %s)".format( - containerId, - completedContainer.getState, - completedContainer.getExitStatus())) - // Hadoop 2.2.X added a ContainerExitStatus we should switch to use - // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus() != 0) { - logInfo("Container marked as failed: " + containerId) - numExecutorsFailed.incrementAndGet() - } - } - - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).getOrElse(null) - assert (host != null) - - val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null) - assert (containerSet != null) - - containerSet -= containerId - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap -= containerId - - // Doing this within locked context, sigh ... move to outside ? - val rack = YarnAllocationHandler.lookupRack(conf, host) - if (rack != null) { - val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 - if (rackCount > 0) { - allocatedRackCount.put(rack, rackCount) - } else { - allocatedRackCount.remove(rack) - } - } - } - } - } - logDebug(""" - Finished processing %d completed containers. - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - completedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } - } - - def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = { - // First generate modified racks and new set of hosts under it : then issue requests - val rackToCounts = new HashMap[String, Int]() - - // Within this lock - used to read/write to the rack related maps too. - for (container <- hostContainers) { - val candidateHost = container.getHostName - val candidateNumContainers = container.getNumContainers - assert(YarnAllocationHandler.ANY_HOST != candidateHost) - - val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) - if (rack != null) { - var count = rackToCounts.getOrElse(rack, 0) - count += candidateNumContainers - rackToCounts.put(rack, count) - } - } - - val requestedContainers: ArrayBuffer[ResourceRequest] = - new ArrayBuffer[ResourceRequest](rackToCounts.size) - for ((rack, count) <- rackToCounts){ - requestedContainers += - createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY) - } - - requestedContainers.toList - } - - def allocatedContainersOnHost(host: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedHostToContainersMap.getOrElse(host, Set()).size - } - retval - } - - def allocatedContainersOnRack(rack: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedRackCount.getOrElse(rack, 0) - } - retval - } + conf: Configuration, + sparkConf: SparkConf, + resourceManager: AMRMProtocol, + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + preferredNodes: collection.Map[String, collection.Set[SplitInfo]], + securityMgr: SecurityManager) + extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { - private def allocateExecutorResources(numExecutors: Int): AllocateResponse = { + private val lastResponseId = new AtomicInteger() + private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() + override protected def allocateContainers(count: Int): YarnAllocateResponse = { var resourceRequests: List[ResourceRequest] = null - // default. - if (numExecutors <= 0 || preferredHostToCount.isEmpty) { - logDebug("numExecutors: " + numExecutors + ", host preferences: " + - preferredHostToCount.isEmpty) - resourceRequests = List(createResourceRequest( - AllocationType.ANY, null, numExecutors, YarnAllocationHandler.PRIORITY)) + logDebug("numExecutors: " + count) + if (count <= 0) { + resourceRequests = List() + } else if (preferredHostToCount.isEmpty) { + logDebug("host preferences is empty") + resourceRequests = List(createResourceRequest( + AllocationType.ANY, null, count, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)) } else { // request for all hosts in preferred nodes and for numExecutors - // candidates.size, request by default allocation policy. @@ -421,7 +71,7 @@ private[yarn] class YarnAllocationHandler( AllocationType.HOST, candidateHost, requiredCount, - YarnAllocationHandler.PRIORITY) + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) } } val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests( @@ -430,8 +80,8 @@ private[yarn] class YarnAllocationHandler( val anyContainerRequests: ResourceRequest = createResourceRequest( AllocationType.ANY, resource = null, - numExecutors, - YarnAllocationHandler.PRIORITY) + count, + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest]( hostContainerRequests.size + rackContainerRequests.size + 1) @@ -452,8 +102,8 @@ private[yarn] class YarnAllocationHandler( val releasedContainerList = createReleasedContainerList() req.addAllReleases(releasedContainerList) - if (numExecutors > 0) { - logInfo("Allocating %d executor containers with %d of memory each.".format(numExecutors, + if (count > 0) { + logInfo("Allocating %d executor containers with %d of memory each.".format(count, executorMemory + memoryOverhead)) } else { logDebug("Empty allocation req .. release : " + releasedContainerList) @@ -467,9 +117,42 @@ private[yarn] class YarnAllocationHandler( request.getPriority, request.getCapability)) } - resourceManager.allocate(req) + new AlphaAllocateResponse(resourceManager.allocate(req).getAMResponse()) + } + + override protected def releaseContainer(container: Container) = { + releaseList.add(container.getId()) } + private def createRackResourceRequests(hostContainers: List[ResourceRequest]): + List[ResourceRequest] = { + // First generate modified racks and new set of hosts under it : then issue requests + val rackToCounts = new HashMap[String, Int]() + + // Within this lock - used to read/write to the rack related maps too. + for (container <- hostContainers) { + val candidateHost = container.getHostName + val candidateNumContainers = container.getNumContainers + assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost) + + val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) + if (rack != null) { + var count = rackToCounts.getOrElse(rack, 0) + count += candidateNumContainers + rackToCounts.put(rack, count) + } + } + + val requestedContainers: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts){ + requestedContainers += + createResourceRequest(AllocationType.RACK, rack, count, + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) + } + + requestedContainers.toList + } private def createResourceRequest( requestType: AllocationType.AllocationType, @@ -481,12 +164,12 @@ private[yarn] class YarnAllocationHandler( // There must be a third request - which is ANY : that will be specially handled. requestType match { case AllocationType.HOST => { - assert(YarnAllocationHandler.ANY_HOST != resource) + assert(YarnSparkHadoopUtil.ANY_HOST != resource) val hostname = resource val nodeLocal = createResourceRequestImpl(hostname, numExecutors, priority) // Add to host->rack mapping - YarnAllocationHandler.populateRackInfo(conf, hostname) + YarnSparkHadoopUtil.populateRackInfo(conf, hostname) nodeLocal } @@ -495,7 +178,7 @@ private[yarn] class YarnAllocationHandler( createResourceRequestImpl(rack, numExecutors, priority) } case AllocationType.ANY => createResourceRequestImpl( - YarnAllocationHandler.ANY_HOST, numExecutors, priority) + YarnSparkHadoopUtil.ANY_HOST, numExecutors, priority) case _ => throw new IllegalArgumentException( "Unexpected/unsupported request type: " + requestType) } @@ -522,169 +205,24 @@ private[yarn] class YarnAllocationHandler( rsrcRequest } - def createReleasedContainerList(): ArrayBuffer[ContainerId] = { - + private def createReleasedContainerList(): ArrayBuffer[ContainerId] = { val retval = new ArrayBuffer[ContainerId](1) // Iterator on COW list ... - for (container <- releasedContainerList.iterator()){ + for (container <- releaseList.iterator()){ retval += container } // Remove from the original list. - if (! retval.isEmpty) { - releasedContainerList.removeAll(retval) - for (v <- retval) pendingReleaseContainers.put(v, true) - logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " + - pendingReleaseContainers) + if (!retval.isEmpty) { + releaseList.removeAll(retval) + logInfo("Releasing " + retval.size + " containers.") } - retval } -} - -object YarnAllocationHandler { - - val ANY_HOST = "*" - // All requests are issued with same priority : we do not (yet) have any distinction between - // request types (like map/reduce in hadoop for example) - val PRIORITY = 1 - - // Additional memory overhead - in mb - val MEMORY_OVERHEAD = 384 - - // Host to rack map - saved from allocation requests - // We are expecting this not to change. - // Note that it is possible for this to change : and RM will indicate that to us via update - // response to allocate. But we are punting on handling that for now. - private val hostToRack = new ConcurrentHashMap[String, String]() - private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() - - def newAllocator( - conf: Configuration, - resourceManager: AMRMProtocol, - appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments, - sparkConf: SparkConf): YarnAllocationHandler = { - - new YarnAllocationHandler( - conf, - resourceManager, - appAttemptId, - args.numExecutors, - args.executorMemory, - args.executorCores, - Map[String, Int](), - Map[String, Int](), - sparkConf) + private class AlphaAllocateResponse(response: AMResponse) extends YarnAllocateResponse { + override def getAllocatedContainers() = response.getAllocatedContainers() + override def getAvailableResources() = response.getAvailableResources() + override def getCompletedContainersStatuses() = response.getCompletedContainersStatuses() } - def newAllocator( - conf: Configuration, - resourceManager: AMRMProtocol, - appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments, - map: collection.Map[String, - collection.Set[SplitInfo]], - sparkConf: SparkConf): YarnAllocationHandler = { - - val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) - new YarnAllocationHandler( - conf, - resourceManager, - appAttemptId, - args.numExecutors, - args.executorMemory, - args.executorCores, - hostToCount, - rackToCount, - sparkConf) - } - - def newAllocator( - conf: Configuration, - resourceManager: AMRMProtocol, - appAttemptId: ApplicationAttemptId, - maxExecutors: Int, - executorMemory: Int, - executorCores: Int, - map: collection.Map[String, collection.Set[SplitInfo]], - sparkConf: SparkConf): YarnAllocationHandler = { - - val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) - - new YarnAllocationHandler( - conf, - resourceManager, - appAttemptId, - maxExecutors, - executorMemory, - executorCores, - hostToCount, - rackToCount, - sparkConf) - } - - // A simple method to copy the split info map. - private def generateNodeToWeight( - conf: Configuration, - input: collection.Map[String, collection.Set[SplitInfo]]) : - // host to count, rack to count - (Map[String, Int], Map[String, Int]) = { - - if (input == null) return (Map[String, Int](), Map[String, Int]()) - - val hostToCount = new HashMap[String, Int] - val rackToCount = new HashMap[String, Int] - - for ((host, splits) <- input) { - val hostCount = hostToCount.getOrElse(host, 0) - hostToCount.put(host, hostCount + splits.size) - - val rack = lookupRack(conf, host) - if (rack != null){ - val rackCount = rackToCount.getOrElse(host, 0) - rackToCount.put(host, rackCount + splits.size) - } - } - - (hostToCount.toMap, rackToCount.toMap) - } - - def lookupRack(conf: Configuration, host: String): String = { - if (!hostToRack.contains(host)) populateRackInfo(conf, host) - hostToRack.get(host) - } - - def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { - val set = rackToHostSet.get(rack) - if (set == null) return None - - // No better way to get a Set[String] from JSet ? - val convertedSet: collection.mutable.Set[String] = set - Some(convertedSet.toSet) - } - - def populateRackInfo(conf: Configuration, hostname: String) { - Utils.checkHost(hostname) - - if (!hostToRack.containsKey(hostname)) { - // If there are repeated failures to resolve, all to an ignore list ? - val rackInfo = RackResolver.resolve(conf, hostname) - if (rackInfo != null && rackInfo.getNetworkLocation != null) { - val rack = rackInfo.getNetworkLocation - hostToRack.put(hostname, rack) - if (! rackToHostSet.containsKey(rack)) { - rackToHostSet.putIfAbsent(rack, - Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) - } - rackToHostSet.get(rack).add(hostname) - - // TODO(harvey): Figure out this comment... - // Since RackResolver caches, we are disabling this for now ... - } /* else { - // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... - hostToRack.put(hostname, null) - } */ - } - } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala new file mode 100644 index 0000000000000..acf26505e4cf9 --- /dev/null +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.{Map, Set} +import java.net.URI; + +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.scheduler.SplitInfo +import org.apache.spark.util.Utils + +/** + * YarnRMClient implementation for the Yarn alpha API. + */ +private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMClient with Logging { + + private var rpc: YarnRPC = null + private var resourceManager: AMRMProtocol = _ + private var uiHistoryAddress: String = _ + + override def register( + conf: YarnConfiguration, + sparkConf: SparkConf, + preferredNodeLocations: Map[String, Set[SplitInfo]], + uiAddress: String, + uiHistoryAddress: String, + securityMgr: SecurityManager) = { + this.rpc = YarnRPC.create(conf) + this.uiHistoryAddress = uiHistoryAddress + + resourceManager = registerWithResourceManager(conf) + registerApplicationMaster(uiAddress) + + new YarnAllocationHandler(conf, sparkConf, resourceManager, getAttemptId(), args, + preferredNodeLocations, securityMgr) + } + + override def getAttemptId() = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + appAttemptId + } + + override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(getAttemptId()) + finishReq.setFinishApplicationStatus(status) + finishReq.setDiagnostics(diagnostics) + finishReq.setTrackingUrl(uiHistoryAddress) + resourceManager.finishApplicationMaster(finishReq) + } + + override def getProxyHostAndPort(conf: YarnConfiguration) = + YarnConfiguration.getProxyHostAndPort(conf) + + override def getMaxRegAttempts(conf: YarnConfiguration) = + conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) + + private def registerWithResourceManager(conf: YarnConfiguration): AMRMProtocol = { + val rmAddress = NetUtils.createSocketAddr(conf.get(YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) + logInfo("Connecting to ResourceManager at " + rmAddress) + rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + } + + private def registerApplicationMaster(uiAddress: String): RegisterApplicationMasterResponse = { + val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) + .asInstanceOf[RegisterApplicationMasterRequest] + appMasterRequest.setApplicationAttemptId(getAttemptId()) + // Setting this to master host,port - so that the ApplicationReport at client has some + // sensible info. + // Users can then monitor stderr/stdout on that node if required. + appMasterRequest.setHost(Utils.localHostName()) + appMasterRequest.setRpcPort(0) + // remove the scheme from the url if it exists since Hadoop does not expect scheme + appMasterRequest.setTrackingUrl(new URI(uiAddress).getAuthority()) + resourceManager.registerApplicationMaster(appMasterRequest) + } + +} diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala new file mode 100644 index 0000000000000..878b6db546032 --- /dev/null +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.IOException +import java.net.Socket +import java.util.concurrent.atomic.AtomicReference + +import scala.collection.JavaConversions._ +import scala.util.Try + +import akka.actor._ +import akka.remote._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.util.ShutdownHookManager +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.HistoryServer +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter +import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} + +/** + * Common application master functionality for Spark on Yarn. + */ +private[spark] class ApplicationMaster(args: ApplicationMasterArguments, + client: YarnRMClient) extends Logging { + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be + // optimal as more containers are available. Might need to handle this better. + + private val sparkConf = new SparkConf() + private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf) + .asInstanceOf[YarnConfiguration] + private val isDriver = args.userClass != null + + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + + @volatile private var finished = false + @volatile private var finalStatus = FinalApplicationStatus.UNDEFINED + + private var reporterThread: Thread = _ + private var allocator: YarnAllocator = _ + + // Fields used in client mode. + private var actorSystem: ActorSystem = null + private var actor: ActorRef = _ + + // Fields used in cluster mode. + private val sparkContextRef = new AtomicReference[SparkContext](null) + + final def run(): Int = { + val appAttemptId = client.getAttemptId() + + if (isDriver) { + // Set the web ui port to be ephemeral for yarn so we don't conflict with + // other spark processes running on the same box + System.setProperty("spark.ui.port", "0") + + // Set the master property to match the requested mode. + System.setProperty("spark.master", "yarn-cluster") + + // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. + System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) + } + + logInfo("ApplicationAttemptId: " + appAttemptId) + + val cleanupHook = new Runnable { + override def run() { + // If the SparkContext is still registered, shut it down as a best case effort in case + // users do not call sc.stop or do System.exit(). + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + finish(FinalApplicationStatus.SUCCEEDED) + } + + // Cleanup the staging dir after the app is finished, or if it's the last attempt at + // running the AM. + val maxAppAttempts = client.getMaxRegAttempts(yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + if (finished || isLastAttempt) { + cleanupStagingDir() + } + } + } + // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using. + ShutdownHookManager.get().addShutdownHook(cleanupHook, 30) + + // Call this to force generation of secret so it gets populated into the + // Hadoop UGI. This has to happen before the startUserClass which does a + // doAs in order for the credentials to be passed on to the executor containers. + val securityMgr = new SecurityManager(sparkConf) + + if (isDriver) { + runDriver(securityMgr) + } else { + runExecutorLauncher(securityMgr) + } + + if (finalStatus != FinalApplicationStatus.UNDEFINED) { + finish(finalStatus) + 0 + } else { + 1 + } + } + + final def finish(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { + if (!finished) { + logInfo(s"Finishing ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + finished = true + finalStatus = status + try { + if (Thread.currentThread() != reporterThread) { + reporterThread.interrupt() + reporterThread.join() + } + } finally { + client.shutdown(status, Option(diagnostics).getOrElse("")) + } + } + } + + private def sparkContextInitialized(sc: SparkContext) = { + sparkContextRef.synchronized { + sparkContextRef.compareAndSet(null, sc) + sparkContextRef.notifyAll() + } + } + + private def sparkContextStopped(sc: SparkContext) = { + sparkContextRef.compareAndSet(sc, null) + } + + private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { + val sc = sparkContextRef.get() + + val appId = client.getAttemptId().getApplicationId().toString() + val historyAddress = + sparkConf.getOption("spark.yarn.historyServer.address") + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } + .getOrElse("") + + allocator = client.register(yarnConf, + if (sc != null) sc.getConf else sparkConf, + if (sc != null) sc.preferredNodeLocationData else Map(), + uiAddress, + historyAddress, + securityMgr) + + allocator.allocateResources() + reporterThread = launchReporterThread() + } + + private def runDriver(securityMgr: SecurityManager): Unit = { + addAmIpFilter() + val userThread = startUserClass() + + // This a bit hacky, but we need to wait until the spark.driver.port property has + // been set by the Thread executing the user class. + val sc = waitForSparkContextInitialized() + + // If there is no SparkContext at this point, just fail the app. + if (sc == null) { + finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.") + } else { + registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + try { + userThread.join() + } finally { + // In cluster mode, ask the reporter thread to stop since the user app is finished. + reporterThread.interrupt() + } + } + } + + private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { + actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, + conf = sparkConf, securityManager = securityMgr)._1 + actor = waitForSparkDriver() + addAmIpFilter() + registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + + // In client mode the actor will stop the reporter thread. + reporterThread.join() + finalStatus = FinalApplicationStatus.SUCCEEDED + } + + private def launchReporterThread(): Thread = { + // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. + val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) + + // must be <= expiryInterval / 2. + val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) + + val t = new Thread { + override def run() { + while (!finished) { + checkNumExecutorsFailed() + if (!finished) { + logDebug("Sending progress") + allocator.allocateResources() + try { + Thread.sleep(interval) + } catch { + case e: InterruptedException => + } + } + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.setName("Reporter") + t.start() + logInfo("Started progress reporter thread - sleep time : " + interval) + t + } + + /** + * Clean up the staging directory. + */ + private def cleanupStagingDir() { + val fs = FileSystem.get(yarnConf) + var stagingDirPath: Path = null + try { + val preserveFiles = sparkConf.get("spark.yarn.preserve.staging.files", "false").toBoolean + if (!preserveFiles) { + stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) + if (stagingDirPath == null) { + logError("Staging directory is null") + return + } + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logError("Failed to cleanup staging dir " + stagingDirPath, ioe) + } + } + + private def waitForSparkContextInitialized(): SparkContext = { + logInfo("Waiting for spark context initialization") + try { + sparkContextRef.synchronized { + var count = 0 + val waitTime = 10000L + val numTries = sparkConf.getInt("spark.yarn.ApplicationMaster.waitTries", 10) + while (sparkContextRef.get() == null && count < numTries && !finished) { + logInfo("Waiting for spark context initialization ... " + count) + count = count + 1 + sparkContextRef.wait(waitTime) + } + + val sparkContext = sparkContextRef.get() + assert(sparkContext != null || count >= numTries) + if (sparkContext == null) { + logError( + "Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d".format( + count * waitTime, numTries)) + } + sparkContext + } + } + } + + private def waitForSparkDriver(): ActorRef = { + logInfo("Waiting for Spark driver to be reachable.") + var driverUp = false + val hostport = args.userArgs(0) + val (driverHost, driverPort) = Utils.parseHostPort(hostport) + while (!driverUp) { + try { + val socket = new Socket(driverHost, driverPort) + socket.close() + logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) + driverUp = true + } catch { + case e: Exception => + logError("Failed to connect to driver at %s:%s, retrying ...". + format(driverHost, driverPort)) + Thread.sleep(100) + } + } + sparkConf.set("spark.driver.host", driverHost) + sparkConf.set("spark.driver.port", driverPort.toString) + + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, + driverHost, + driverPort.toString, + CoarseGrainedSchedulerBackend.ACTOR_NAME) + actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") + } + + private def checkNumExecutorsFailed() = { + if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finish(FinalApplicationStatus.FAILED, "Max number of executor failures reached.") + + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from checkNumExecutorsFailed") + sc.stop() + } + } + } + + /** Add the Yarn IP filter that is required for properly securing the UI. */ + private def addAmIpFilter() = { + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val proxy = client.getProxyHostAndPort(yarnConf) + val parts = proxy.split(":") + val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + val uriBase = "http://" + proxy + proxyBase + val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + + if (isDriver) { + System.setProperty("spark.ui.filters", amFilter) + System.setProperty(s"spark.$amFilter.params", params) + } else { + actor ! AddWebUIFilter(amFilter, params, proxyBase) + } + } + + private def startUserClass(): Thread = { + logInfo("Starting the user JAR in a separate Thread") + System.setProperty("spark.executor.instances", args.numExecutors.toString) + val mainMethod = Class.forName(args.userClass, false, + Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) + + val t = new Thread { + override def run() { + var status = FinalApplicationStatus.FAILED + try { + // Copy + val mainArgs = new Array[String](args.userArgs.size) + args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) + mainMethod.invoke(null, mainArgs) + // Some apps have "System.exit(0)" at the end. The user thread will stop here unless + // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. + status = FinalApplicationStatus.SUCCEEDED + } finally { + logDebug("Finishing main") + } + finalStatus = status + } + } + t.setName("Driver") + t.start() + t + } + + // Actor used to monitor the driver when running in client deploy mode. + private class MonitorActor(driverUrl: String) extends Actor { + + var driver: ActorSelection = _ + + override def preStart() = { + logInfo("Listen to driver: " + driverUrl) + driver = context.actorSelection(driverUrl) + // Send a hello message to establish the connection, after which + // we can monitor Lifecycle Events. + driver ! "Hello" + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive = { + case x: DisassociatedEvent => + logInfo(s"Driver terminated or disconnected! Shutting down. $x") + finish(FinalApplicationStatus.SUCCEEDED) + case x: AddWebUIFilter => + logInfo(s"Add WebUI Filter. $x") + driver ! x + } + + } + +} + +object ApplicationMaster extends Logging { + + private var master: ApplicationMaster = _ + + def main(args: Array[String]) = { + SignalLogger.register(log) + val amArgs = new ApplicationMasterArguments(args) + SparkHadoopUtil.get.runAsSparkUser { () => + master = new ApplicationMaster(amArgs, new YarnRMClientImpl(amArgs)) + System.exit(master.run()) + } + } + + private[spark] def sparkContextInitialized(sc: SparkContext) = { + master.sparkContextInitialized(sc) + } + + private[spark] def sparkContextStopped(sc: SparkContext) = { + master.sparkContextStopped(sc) + } + +} + +/** + * This object does not provide any special functionality. It exists so that it's easy to tell + * apart the client-mode AM from the cluster-mode AM when using tools such as ps or jps. + */ +object ExecutorLauncher { + + def main(args: Array[String]) = { + ApplicationMaster.main(args) + } + +} diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 424b0fb0936f2..3e6b96fb63cea 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -63,11 +63,6 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail - case Nil => - if (userJar == null || userClass == null) { - printUsageAndExit(1) - } - case _ => printUsageAndExit(1, args) } @@ -80,16 +75,17 @@ class ApplicationMasterArguments(val args: Array[String]) { if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) } - System.err.println( - "Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options] \n" + - "Options:\n" + - " --jar JAR_PATH Path to your application's JAR file (required)\n" + - " --class CLASS_NAME Name of your application's main class (required)\n" + - " --args ARGS Arguments to be passed to your application's main class.\n" + - " Mutliple invocations are possible, each will be passed in order.\n" + - " --num-executors NUM Number of executors to start (Default: 2)\n" + - " --executor-cores NUM Number of cores for the executors (Default: 1)\n" + - " --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G)\n") + System.err.println(""" + |Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options] + |Options: + | --jar JAR_PATH Path to your application's JAR file + | --class CLASS_NAME Name of your application's main class + | --args ARGS Arguments to be passed to your application's main class. + | Mutliple invocations are possible, each will be passed in order. + | --num-executors NUM Number of executors to start (Default: 2) + | --executor-cores NUM Number of cores for the executors (Default: 1) + | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + """.stripMargin) System.exit(exitCode) } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 62f9b3cf5ab88..40d8d6d6e6961 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -37,9 +37,7 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { var numExecutors = 2 var amQueue = sparkConf.get("QUEUE", "default") var amMemory: Int = 512 // MB - var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" var appName: String = "Spark" - var inputFormatInfo: List[InputFormatInfo] = null var priority = 0 parseArgs(args.toList) @@ -58,8 +56,7 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { private def parseArgs(inputArgs: List[String]): Unit = { val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() - val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]() - + var args = inputArgs while (!args.isEmpty) { @@ -80,10 +77,7 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { args = tail case ("--master-class" | "--am-class") :: value :: tail => - if (args(0) == "--master-class") { - println("--master-class is deprecated. Use --am-class instead.") - } - amClass = value + println(s"${args(0)} is deprecated and is not used anymore.") args = tail case ("--master-memory" | "--driver-memory") :: MemoryParam(value) :: tail => @@ -135,9 +129,6 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { args = tail case Nil => - if (userClass == null) { - throw new IllegalArgumentException(getUsageMessage()) - } case _ => throw new IllegalArgumentException(getUsageMessage(args)) @@ -145,7 +136,6 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { } userArgs = userArgsBuffer.readOnly - inputFormatInfo = inputFormatMap.values.toList } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 3897b3a373a8c..c96f731923d22 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -42,12 +42,6 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, Spar /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The * Client submits an application to the YARN ResourceManager. - * - * Depending on the deployment mode this will launch one of two application master classes: - * 1. In cluster mode, it will launch an [[org.apache.spark.deploy.yarn.ApplicationMaster]] - * which launches a driver program inside of the cluster. - * 2. In client mode, it will launch an [[org.apache.spark.deploy.yarn.ExecutorLauncher]] to - * request executors on behalf of a driver running outside of the cluster. */ trait ClientBase extends Logging { val args: ClientArguments @@ -67,14 +61,11 @@ trait ClientBase extends Logging { // Additional memory overhead - in mb. protected def memoryOverhead: Int = sparkConf.getInt("spark.yarn.driver.memoryOverhead", - YarnAllocationHandler.MEMORY_OVERHEAD) + YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) // TODO(harvey): This could just go in ClientArguments. def validateArgs() = { Map( - ((args.userJar == null && args.amClass == classOf[ApplicationMaster].getName) -> - "Error: You must specify a user jar when running in standalone mode!"), - (args.userClass == null) -> "Error: You must specify a user class!", (args.numExecutors <= 0) -> "Error: You must specify at least 1 executor!", (args.amMemory <= memoryOverhead) -> ("Error: AM memory size must be" + "greater than: " + memoryOverhead), @@ -218,7 +209,7 @@ trait ClientBase extends Logging { if (! localPath.isEmpty()) { val localURI = new URI(localPath) if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) { - val setPermissions = if (destName.equals(ClientBase.APP_JAR)) true else false + val setPermissions = destName.equals(ClientBase.APP_JAR) val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions) val destFs = FileSystem.get(destPath.toUri(), conf) distCacheMgr.addResource(destFs, conf, destPath, localResources, LocalResourceType.FILE, @@ -309,8 +300,6 @@ trait ClientBase extends Logging { retval.toString } - def calculateAMMemory(newApp: GetNewApplicationResponse): Int - def setupSecurityToken(amContainer: ContainerLaunchContext) def createContainerLaunchContext( @@ -321,6 +310,8 @@ trait ClientBase extends Logging { val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) + val isLaunchingDriver = args.userClass != null + // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's // SparkContext will not let that set spark* system properties, which is expected behavior for @@ -329,7 +320,7 @@ trait ClientBase extends Logging { // Note that to warn the user about the deprecation in cluster mode, some code from // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition // described above). - if (args.amClass == classOf[ApplicationMaster].getName) { + if (isLaunchingDriver) { sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = s""" @@ -353,7 +344,7 @@ trait ClientBase extends Logging { } amContainer.setEnvironment(env) - val amMemory = calculateAMMemory(newApp) + val amMemory = args.amMemory val javaOpts = ListBuffer[String]() @@ -389,7 +380,7 @@ trait ClientBase extends Logging { javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } - if (args.amClass == classOf[ApplicationMaster].getName) { + if (isLaunchingDriver) { sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) .foreach(opts => javaOpts += opts) @@ -397,22 +388,37 @@ trait ClientBase extends Logging { .foreach(p => javaOpts += s"-Djava.library.path=$p") } - // Command for the ApplicationMaster - val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ - javaOpts ++ - Seq(args.amClass, "--class", YarnSparkHadoopUtil.escapeForShell(args.userClass), - "--jar ", YarnSparkHadoopUtil.escapeForShell(args.userJar), - userArgsToString(args), - "--executor-memory", args.executorMemory.toString, + val userClass = + if (args.userClass != null) { + Seq("--class", YarnSparkHadoopUtil.escapeForShell(args.userClass)) + } else { + Nil + } + val amClass = + if (isLaunchingDriver) { + classOf[ApplicationMaster].getName() + } else { + classOf[ApplicationMaster].getName().replace("ApplicationMaster", "ExecutorLauncher") + } + val amArgs = + Seq(amClass) ++ userClass ++ + (if (args.userJar != null) Seq("--jar", args.userJar) else Nil) ++ + Seq("--executor-memory", args.executorMemory.toString, "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString, + userArgsToString(args)) + + // Command for the ApplicationMaster + val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ + javaOpts ++ amArgs ++ + Seq( "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") logInfo("Yarn AM launch context:") - logInfo(s" class: ${args.amClass}") - logInfo(s" env: $env") - logInfo(s" command: ${commands.mkString(" ")}") + logInfo(s" user class: ${args.userClass}") + logInfo(s" env: $env") + logInfo(s" command: ${commands.mkString(" ")}") // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList @@ -422,10 +428,8 @@ trait ClientBase extends Logging { // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) - val acls = Map[ApplicationAccessType, String] ( - ApplicationAccessType.VIEW_APP -> securityManager.getViewAcls, - ApplicationAccessType.MODIFY_APP -> securityManager.getModifyAcls) - amContainer.setApplicationACLs(acls) + amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + amContainer } } @@ -623,7 +627,7 @@ object ClientBase extends Logging { YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, File.pathSeparator) - /** + /** * Get the list of namenodes the user may access. */ private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala new file mode 100644 index 0000000000000..0b8744f4b8bdf --- /dev/null +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.util.{List => JList} +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse + +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +import com.google.common.util.concurrent.ThreadFactoryBuilder + +object AllocationType extends Enumeration { + type AllocationType = Value + val HOST, RACK, ANY = Value +} + +// TODO: +// Too many params. +// Needs to be mt-safe +// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should +// make it more proactive and decoupled. + +// Note that right now, we assume all node asks as uniform in terms of capabilities and priority +// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for +// more info on how we are requesting for containers. + +/** + * Common code for the Yarn container allocator. Contains all the version-agnostic code to + * manage container allocation for a running Spark application. + */ +private[yarn] abstract class YarnAllocator( + conf: Configuration, + sparkConf: SparkConf, + args: ApplicationMasterArguments, + preferredNodes: collection.Map[String, collection.Set[SplitInfo]], + securityMgr: SecurityManager) + extends Logging { + + // These three are locked on allocatedHostToContainersMap. Complementary data structures + // allocatedHostToContainersMap : containers which are running : host, Set + // allocatedContainerToHostMap: container to host mapping. + private val allocatedHostToContainersMap = + new HashMap[String, collection.mutable.Set[ContainerId]]() + + private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + + // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an + // allocated node) + // As with the two data structures above, tightly coupled with them, and to be locked on + // allocatedHostToContainersMap + private val allocatedRackCount = new HashMap[String, Int]() + + // Containers to be released in next request to RM + private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean] + + // Additional memory overhead - in mb. + protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) + + // Number of container requests that have been sent to, but not yet allocated by the + // ApplicationMaster. + private val numPendingAllocate = new AtomicInteger() + private val numExecutorsRunning = new AtomicInteger() + // Used to generate a unique id per executor + private val executorIdCounter = new AtomicInteger() + private val numExecutorsFailed = new AtomicInteger() + + private val maxExecutors = args.numExecutors + + protected val executorMemory = args.executorMemory + protected val executorCores = args.executorCores + protected val (preferredHostToCount, preferredRackToCount) = + generateNodeToWeight(conf, preferredNodes) + + private val launcherPool = new ThreadPoolExecutor( + // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue + sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, + 1, TimeUnit.MINUTES, + new LinkedBlockingQueue[Runnable](), + new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) + launcherPool.allowCoreThreadTimeOut(true) + + def getNumExecutorsRunning: Int = numExecutorsRunning.intValue + + def getNumExecutorsFailed: Int = numExecutorsFailed.intValue + + def allocateResources() = { + val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get() + + if (missing > 0) { + numPendingAllocate.addAndGet(missing) + logInfo("Will Allocate %d executor containers, each with %d memory".format( + missing, + (executorMemory + memoryOverhead))) + } else { + logDebug("Empty allocation request ...") + } + + val allocateResponse = allocateContainers(missing) + val allocatedContainers = allocateResponse.getAllocatedContainers() + + if (allocatedContainers.size > 0) { + var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) + + if (numPendingAllocateNow < 0) { + numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) + } + + logDebug(""" + Allocated containers: %d + Current executor count: %d + Containers released: %s + Cluster resources: %s + """.format( + allocatedContainers.size, + numExecutorsRunning.get(), + releasedContainers, + allocateResponse.getAvailableResources)) + + val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (container <- allocatedContainers) { + if (isResourceConstraintSatisfied(container)) { + // Add the accepted `container` to the host's list of already accepted, + // allocated containers + val host = container.getNodeId.getHost + val containersForHost = hostToContainers.getOrElseUpdate(host, + new ArrayBuffer[Container]()) + containersForHost += container + } else { + // Release container, since it doesn't satisfy resource constraints. + internalReleaseContainer(container) + } + } + + // Find the appropriate containers to use. + // TODO: Cleanup this group-by... + val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (candidateHost <- hostToContainers.keySet) { + val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) + val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) + + val remainingContainersOpt = hostToContainers.get(candidateHost) + assert(remainingContainersOpt.isDefined) + var remainingContainers = remainingContainersOpt.get + + if (requiredHostCount >= remainingContainers.size) { + // Since we have <= required containers, add all remaining containers to + // `dataLocalContainers`. + dataLocalContainers.put(candidateHost, remainingContainers) + // There are no more free containers remaining. + remainingContainers = null + } else if (requiredHostCount > 0) { + // Container list has more containers than we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (dataLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredHostCount) + dataLocalContainers.put(candidateHost, dataLocal) + + // Invariant: remainingContainers == remaining + + // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. + // Add each container in `remaining` to list of containers to release. If we have an + // insufficient number of containers, then the next allocation cycle will reallocate + // (but won't treat it as data local). + // TODO(harvey): Rephrase this comment some more. + for (container <- remaining) internalReleaseContainer(container) + remainingContainers = null + } + + // For rack local containers + if (remainingContainers != null) { + val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) + if (rack != null) { + val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.getOrElse(rack, List()).size + + if (requiredRackCount >= remainingContainers.size) { + // Add all remaining containers to to `dataLocalContainers`. + dataLocalContainers.put(rack, remainingContainers) + remainingContainers = null + } else if (requiredRackCount > 0) { + // Container list has more containers that we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (rackLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredRackCount) + val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, + new ArrayBuffer[Container]()) + + existingRackLocal ++= rackLocal + + remainingContainers = remaining + } + } + } + + if (remainingContainers != null) { + // Not all containers have been consumed - add them to the list of off-rack containers. + offRackContainers.put(candidateHost, remainingContainers) + } + } + + // Now that we have split the containers into various groups, go through them in order: + // first host-local, then rack-local, and finally off-rack. + // Note that the list we create below tries to ensure that not all containers end up within + // a host if there is a sufficiently large number of hosts/containers. + val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers. + for (container <- allocatedContainersToProcess) { + val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() + val executorHostname = container.getNodeId.getHost + val containerId = container.getId + + val executorMemoryOverhead = (executorMemory + memoryOverhead) + assert(container.getResource.getMemory >= executorMemoryOverhead) + + if (numExecutorsRunningNow > maxExecutors) { + logInfo("""Ignoring container %s at host %s, since we already have the required number of + containers for it.""".format(containerId, executorHostname)) + internalReleaseContainer(container) + numExecutorsRunning.decrementAndGet() + } else { + val executorId = executorIdCounter.incrementAndGet().toString + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, + sparkConf.get("spark.driver.host"), + sparkConf.get("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) + + // To be safe, remove the container from `releasedContainers`. + releasedContainers.remove(containerId) + + val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname) + allocatedHostToContainersMap.synchronized { + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, + new HashSet[ContainerId]()) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, executorHostname) + + if (rack != null) { + allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) + } + } + logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( + driverUrl, executorHostname)) + val executorRunnable = new ExecutorRunnable( + container, + conf, + sparkConf, + driverUrl, + executorId, + executorHostname, + executorMemory, + executorCores, + securityMgr) + launcherPool.execute(executorRunnable) + } + } + logDebug(""" + Finished allocating %s containers (from %s originally). + Current number of executors running: %d, + Released containers: %s + """.format( + allocatedContainersToProcess, + allocatedContainers, + numExecutorsRunning.get(), + releasedContainers)) + } + + val completedContainers = allocateResponse.getCompletedContainersStatuses() + if (completedContainers.size > 0) { + logDebug("Completed %d containers".format(completedContainers.size)) + + for (completedContainer <- completedContainers) { + val containerId = completedContainer.getContainerId + + if (releasedContainers.containsKey(containerId)) { + // YarnAllocationHandler already marked the container for release, so remove it from + // `releasedContainers`. + releasedContainers.remove(containerId) + } else { + // Decrement the number of executors running. The next iteration of + // the ApplicationMaster's reporting thread will take care of allocating. + numExecutorsRunning.decrementAndGet() + logInfo("Completed container %s (state: %s, exit status: %s)".format( + containerId, + completedContainer.getState, + completedContainer.getExitStatus())) + // Hadoop 2.2.X added a ContainerExitStatus we should switch to use + // there are some exit status' we shouldn't necessarily count against us, but for + // now I think its ok as none of the containers are expected to exit + if (completedContainer.getExitStatus() != 0) { + logInfo("Container marked as failed: " + containerId) + numExecutorsFailed.incrementAndGet() + } + } + + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val hostOpt = allocatedContainerToHostMap.get(containerId) + assert(hostOpt.isDefined) + val host = hostOpt.get + + val containerSetOpt = allocatedHostToContainersMap.get(host) + assert(containerSetOpt.isDefined) + val containerSet = containerSetOpt.get + + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) + } + + allocatedContainerToHostMap.remove(containerId) + + // TODO: Move this part outside the synchronized block? + val rack = YarnSparkHadoopUtil.lookupRack(conf, host) + if (rack != null) { + val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 + if (rackCount > 0) { + allocatedRackCount.put(rack, rackCount) + } else { + allocatedRackCount.remove(rack) + } + } + } + } + } + logDebug(""" + Finished processing %d completed containers. + Current number of executors running: %d, + Released containers: %s + """.format( + completedContainers.size, + numExecutorsRunning.get(), + releasedContainers)) + } + } + + protected def allocatedContainersOnHost(host: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedHostToContainersMap.getOrElse(host, Set()).size + } + retval + } + + protected def allocatedContainersOnRack(rack: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedRackCount.getOrElse(rack, 0) + } + retval + } + + private def isResourceConstraintSatisfied(container: Container): Boolean = { + container.getResource.getMemory >= (executorMemory + memoryOverhead) + } + + // A simple method to copy the split info map. + private def generateNodeToWeight( + conf: Configuration, + input: collection.Map[String, collection.Set[SplitInfo]] + ): (Map[String, Int], Map[String, Int]) = { + + if (input == null) { + return (Map[String, Int](), Map[String, Int]()) + } + + val hostToCount = new HashMap[String, Int] + val rackToCount = new HashMap[String, Int] + + for ((host, splits) <- input) { + val hostCount = hostToCount.getOrElse(host, 0) + hostToCount.put(host, hostCount + splits.size) + + val rack = YarnSparkHadoopUtil.lookupRack(conf, host) + if (rack != null) { + val rackCount = rackToCount.getOrElse(host, 0) + rackToCount.put(host, rackCount + splits.size) + } + } + + (hostToCount.toMap, rackToCount.toMap) + } + + private def internalReleaseContainer(container: Container) = { + releasedContainers.put(container.getId(), true) + releaseContainer(container) + } + + /** + * Called to allocate containers in the cluster. + * + * @param count Number of containers to allocate. + * If zero, should still contact RM (as a heartbeat). + * @return Response to the allocation request. + */ + protected def allocateContainers(count: Int): YarnAllocateResponse + + /** Called to release a previously allocated container. */ + protected def releaseContainer(container: Container): Unit + + /** + * Defines the interface for an allocate response from the RM. This is needed since the alpha + * and stable interfaces differ here in ways that cannot be fixed using other routes. + */ + protected trait YarnAllocateResponse { + + def getAllocatedContainers(): JList[Container] + + def getAvailableResources(): Resource + + def getCompletedContainersStatuses(): JList[ContainerStatus] + + } + +} diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala new file mode 100644 index 0000000000000..ed65e56b3e413 --- /dev/null +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.{Map, Set} + +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.api.records._ + +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} +import org.apache.spark.scheduler.SplitInfo + +/** + * Interface that defines a Yarn RM client. Abstracts away Yarn version-specific functionality that + * is used by Spark's AM. + */ +trait YarnRMClient { + + /** + * Registers the application master with the RM. + * + * @param conf The Yarn configuration. + * @param sparkConf The Spark configuration. + * @param preferredNodeLocations Map with hints about where to allocate containers. + * @param uiAddress Address of the SparkUI. + * @param uiHistoryAddress Address of the application on the History Server. + */ + def register( + conf: YarnConfiguration, + sparkConf: SparkConf, + preferredNodeLocations: Map[String, Set[SplitInfo]], + uiAddress: String, + uiHistoryAddress: String, + securityMgr: SecurityManager): YarnAllocator + + /** + * Shuts down the AM. Guaranteed to only be called once. + * + * @param status The final status of the AM. + * @param diagnostics Diagnostics message to include in the final status. + */ + def shutdown(status: FinalApplicationStatus, diagnostics: String = ""): Unit + + /** Returns the attempt ID. */ + def getAttemptId(): ApplicationAttemptId + + /** Returns the RM's proxy host and port. */ + def getProxyHostAndPort(conf: YarnConfiguration): String + + /** Returns the maximum number of attempts to register the AM. */ + def getMaxRegAttempts(conf: YarnConfiguration): Int + +} diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 10aef5eb2486f..4a33e34c3bfc7 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,8 +17,11 @@ package org.apache.spark.deploy.yarn +import java.lang.{Boolean => JBoolean} +import java.util.{Collections, Set => JSet} import java.util.regex.Matcher import java.util.regex.Pattern +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.HashMap @@ -29,11 +32,13 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringInterner import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants +import org.apache.hadoop.yarn.api.records.ApplicationAccessType +import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.deploy.history.HistoryServer +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils /** * Contains util methods to interact with Hadoop from spark. @@ -49,7 +54,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems // Always create a new config, dont reuse yarnConf. - override def newConfiguration(): Configuration = new YarnConfiguration(new Configuration()) + override def newConfiguration(conf: SparkConf): Configuration = + new YarnConfiguration(super.newConfiguration(conf)) // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster override def addCredentials(conf: JobConf) { @@ -79,6 +85,21 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } object YarnSparkHadoopUtil { + // Additional memory overhead - in mb. + val DEFAULT_MEMORY_OVERHEAD = 384 + + val ANY_HOST = "*" + + // All RM requests are issued with same priority : we do not (yet) have any distinction between + // request types (like map/reduce in hadoop for example) + val RM_REQUEST_PRIORITY = 1 + + // Host to rack map - saved from allocation requests. We are expecting this not to change. + // Note that it is possible for this to change : and ResourceManager will indicate that to us via + // update response to allocate. But we are punting on handling that for now. + private val hostToRack = new ConcurrentHashMap[String, String]() + private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + def addToEnvironment( env: HashMap[String, String], variable: String, @@ -135,19 +156,6 @@ object YarnSparkHadoopUtil { } } - def getUIHistoryAddress(sc: SparkContext, conf: SparkConf) : String = { - val eventLogDir = sc.eventLogger match { - case Some(logger) => logger.getApplicationLogDir() - case None => "" - } - val historyServerAddress = conf.get("spark.yarn.historyServer.address", "") - if (historyServerAddress != "" && eventLogDir != "") { - historyServerAddress + HistoryServer.UI_PATH_PREFIX + s"/$eventLogDir" - } else { - "" - } - } - /** * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The @@ -173,4 +181,43 @@ object YarnSparkHadoopUtil { } } + private[spark] def lookupRack(conf: Configuration, host: String): String = { + if (!hostToRack.contains(host)) { + populateRackInfo(conf, host) + } + hostToRack.get(host) + } + + private[spark] def populateRackInfo(conf: Configuration, hostname: String) { + Utils.checkHost(hostname) + + if (!hostToRack.containsKey(hostname)) { + // If there are repeated failures to resolve, all to an ignore list. + val rackInfo = RackResolver.resolve(conf, hostname) + if (rackInfo != null && rackInfo.getNetworkLocation != null) { + val rack = rackInfo.getNetworkLocation + hostToRack.put(hostname, rack) + if (! rackToHostSet.containsKey(rack)) { + rackToHostSet.putIfAbsent(rack, + Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) + } + rackToHostSet.get(rack).add(hostname) + + // TODO(harvey): Figure out what this comment means... + // Since RackResolver caches, we are disabling this for now ... + } /* else { + // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... + hostToRack.put(hostname, null) + } */ + } + } + + private[spark] def getApplicationAclsForYarn(securityMgr: SecurityManager): + Map[ApplicationAccessType, String] = { + Map[ApplicationAccessType, String] ( + ApplicationAccessType.VIEW_APP -> securityMgr.getViewAcls, + ApplicationAccessType.MODIFY_APP -> securityMgr.getModifyAcls + ) + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 3474112ded5d7..254774a6b839e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -18,23 +18,18 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ -import org.apache.hadoop.conf.Configuration -import org.apache.spark.deploy.yarn.YarnAllocationHandler +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils /** - * - * This scheduler launches executors through Yarn - by calling into Client to launch ExecutorLauncher as AM. + * This scheduler launches executors through Yarn - by calling into Client to launch the Spark AM. */ -private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { - - def this(sc: SparkContext) = this(sc, new Configuration()) +private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { val host = Utils.parseHostPort(hostPort)._1 - val retval = YarnAllocationHandler.lookupRack(conf, host) - if (retval != null) Some(retval) else None + Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host)) } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 833e249f9f612..6aa6475fe4a18 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import scala.collection.mutable.ArrayBuffer @@ -55,15 +55,11 @@ private[spark] class YarnClientSchedulerBackend( val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort - conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort) - conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf)) + sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIHostPort) } val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( - "--class", "notused", - "--jar", null, // The primary jar will be added dynamically in SparkContext. - "--args", hostport, - "--am-class", classOf[ExecutorLauncher].getName + "--args", hostport ) // process any optional arguments, given either as environment variables @@ -153,4 +149,7 @@ private[spark] class YarnClientSchedulerBackend( override def sufficientResourcesRegistered(): Boolean = { totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } + + override def applicationId(): Option[String] = Option(appId).map(_.toString()) + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 9aeca4a637d38..4157ff95c2794 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -18,21 +18,18 @@ package org.apache.spark.scheduler.cluster import org.apache.spark._ -import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils -import org.apache.hadoop.conf.Configuration /** - * - * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of + * ApplicationMaster, etc is done */ -private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { logInfo("Created YarnClusterScheduler") - def this(sc: SparkContext) = this(sc, new Configuration()) - // Nothing else for now ... initialize application master : which needs a SparkContext to // determine how to allocate. // Note that only the first creation of a SparkContext influences (and ideally, there must be @@ -42,8 +39,7 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { val host = Utils.parseHostPort(hostPort)._1 - val retval = YarnAllocationHandler.lookupRack(conf, host) - if (retval != null) Some(retval) else None + Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host)) } override def postStartHook() { @@ -51,4 +47,10 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) super.postStartHook() logInfo("YarnClusterScheduler.postStartHook done") } + + override def stop() { + super.stop() + ApplicationMaster.sparkContextStopped(sc) + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 55665220a6f96..39436d0999663 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -28,7 +28,7 @@ private[spark] class YarnClusterSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { var totalExpectedExecutors = 0 - + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 } @@ -47,4 +47,7 @@ private[spark] class YarnClusterSchedulerBackend( override def sufficientResourcesRegistered(): Boolean = { totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } + + override def applicationId(): Option[String] = sc.getConf.getOption("spark.yarn.app.id") + } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 68cc2890f3a22..5480eca7c832c 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -238,9 +238,6 @@ class ClientBaseSuite extends FunSuite with Matchers { val sparkConf: SparkConf, val yarnConf: YarnConfiguration) extends ClientBase { - override def calculateAMMemory(newApp: GetNewApplicationResponse): Int = - throw new UnsupportedOperationException() - override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = throw new UnsupportedOperationException() diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 7650bd4396c12..2cc5abb3a890c 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -20,9 +20,13 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.{FunSuite, Matchers} -import org.apache.spark.Logging +import org.apache.hadoop.yarn.api.records.ApplicationAccessType + +import org.apache.spark.{Logging, SecurityManager, SparkConf} + class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { @@ -61,4 +65,87 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { } } + test("Yarn configuration override") { + val key = "yarn.nodemanager.hostname" + val default = new YarnConfiguration() + + val sparkConf = new SparkConf() + .set("spark.hadoop." + key, "someHostName") + val yarnConf = new YarnSparkHadoopUtil().newConfiguration(sparkConf) + + yarnConf.getClass() should be (classOf[YarnConfiguration]) + yarnConf.get(key) should not be default.get(key) + } + + + test("test getApplicationAclsForYarn acls on") { + + // spark acls on, just pick up default user + val sparkConf = new SparkConf() + sparkConf.set("spark.acls.enable", "true") + + val securityMgr = new SecurityManager(sparkConf) + val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr) + + val viewAcls = acls.get(ApplicationAccessType.VIEW_APP) + val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) + + viewAcls match { + case Some(vacls) => { + val aclSet = vacls.split(',').map(_.trim).toSet + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + modifyAcls match { + case Some(macls) => { + val aclSet = macls.split(',').map(_.trim).toSet + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + } + + test("test getApplicationAclsForYarn acls on and specify users") { + + // default spark acls are on and specify acls + val sparkConf = new SparkConf() + sparkConf.set("spark.acls.enable", "true") + sparkConf.set("spark.ui.view.acls", "user1,user2") + sparkConf.set("spark.modify.acls", "user3,user4") + + val securityMgr = new SecurityManager(sparkConf) + val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr) + + val viewAcls = acls.get(ApplicationAccessType.VIEW_APP) + val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) + + viewAcls match { + case Some(vacls) => { + val aclSet = vacls.split(',').map(_.trim).toSet + assert(aclSet.contains("user1")) + assert(aclSet.contains("user2")) + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + modifyAcls match { + case Some(macls) => { + val aclSet = macls.split(',').map(_.trim).toSet + assert(aclSet.contains("user3")) + assert(aclSet.contains("user4")) + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + + } } diff --git a/yarn/pom.xml b/yarn/pom.xml index 3faaf053634d6..7fcd7ee0d4547 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index b6c8456d06684..fd934b7726181 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala deleted file mode 100644 index 1c4005fd8e78e..0000000000000 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ /dev/null @@ -1,413 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn - -import java.io.IOException -import java.util.concurrent.CopyOnWriteArrayList -import java.util.concurrent.atomic.AtomicReference - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.util.ShutdownHookManager -import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.protocolrecords._ -import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.client.api.AMRMClient -import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.ConverterUtils -import org.apache.hadoop.yarn.webapp.util.WebAppUtils - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.{SignalLogger, Utils} - - -/** - * An application master that runs the user's driver program and allocates executors. - */ -class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, - sparkConf: SparkConf) extends Logging { - - def this(args: ApplicationMasterArguments, sparkConf: SparkConf) = - this(args, new Configuration(), sparkConf) - - def this(args: ApplicationMasterArguments) = this(args, new SparkConf()) - - private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - private var appAttemptId: ApplicationAttemptId = _ - private var userThread: Thread = _ - private val fs = FileSystem.get(yarnConf) - - private var yarnAllocator: YarnAllocationHandler = _ - private var isFinished: Boolean = false - private var uiAddress: String = _ - private var uiHistoryAddress: String = _ - private val maxAppAttempts: Int = conf.getInt( - YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) - private var isLastAMRetry: Boolean = true - private var amClient: AMRMClient[ContainerRequest] = _ - - // Default to numExecutors * 2, with minimum of 3 - private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) - - private var registered = false - - def run() { - // Set the web ui port to be ephemeral for yarn so we don't conflict with - // other spark processes running on the same box - System.setProperty("spark.ui.port", "0") - - // When running the AM, the Spark master is always "yarn-cluster" - System.setProperty("spark.master", "yarn-cluster") - - // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using. - ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) - - appAttemptId = ApplicationMaster.getApplicationAttemptId() - logInfo("ApplicationAttemptId: " + appAttemptId) - isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts - amClient = AMRMClient.createAMRMClient() - amClient.init(yarnConf) - amClient.start() - - // setup AmIpFilter for the SparkUI - do this before we start the UI - addAmIpFilter() - - ApplicationMaster.register(this) - - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserClass which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) - - // Start the user's JAR - userThread = startUserClass() - - // This a bit hacky, but we need to wait until the spark.driver.port property has - // been set by the Thread executing the user class. - waitForSparkContextInitialized() - - // Do this after Spark master is up and SparkContext is created so that we can register UI Url. - synchronized { - if (!isFinished) { - registerApplicationMaster() - registered = true - } - } - - // Allocate all containers - allocateExecutors() - - // Launch thread that will heartbeat to the RM so it won't think the app has died. - launchReporterThread() - - // Wait for the user class to finish - userThread.join() - - System.exit(0) - } - - // add the yarn amIpFilter that Yarn requires for properly securing the UI - private def addAmIpFilter() { - val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" - System.setProperty("spark.ui.filters", amFilter) - val proxy = WebAppUtils.getProxyHostAndPort(conf) - val parts : Array[String] = proxy.split(":") - val uriBase = "http://" + proxy + - System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) - - val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase - System.setProperty( - "spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) - } - - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - logInfo("Registering the ApplicationMaster") - amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) - } - - private def startUserClass(): Thread = { - logInfo("Starting the user JAR in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) - val mainMethod = Class.forName( - args.userClass, - false, - Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) - val t = new Thread { - override def run() { - var succeeded = false - try { - // Copy - val mainArgs = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) - // Some apps have "System.exit(0)" at the end. The user thread will stop here unless - // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. - succeeded = true - } finally { - logDebug("Finishing main") - isLastAMRetry = true - if (succeeded) { - ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) - } else { - ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED) - } - } - } - } - t.setName("Driver") - t.start() - t - } - - // This needs to happen before allocateExecutors() - private def waitForSparkContextInitialized() { - logInfo("Waiting for Spark context initialization") - try { - var sparkContext: SparkContext = null - ApplicationMaster.sparkContextRef.synchronized { - var numTries = 0 - val waitTime = 10000L - val maxNumTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 10) - while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries - && !isFinished) { - logInfo("Waiting for Spark context initialization ... " + numTries) - numTries = numTries + 1 - ApplicationMaster.sparkContextRef.wait(waitTime) - } - sparkContext = ApplicationMaster.sparkContextRef.get() - assert(sparkContext != null || numTries >= maxNumTries) - - if (sparkContext != null) { - uiAddress = sparkContext.ui.appUIHostPort - uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) - this.yarnAllocator = YarnAllocationHandler.newAllocator( - yarnConf, - amClient, - appAttemptId, - args, - sparkContext.preferredNodeLocationData, - sparkContext.getConf) - } else { - logWarning("Unable to retrieve SparkContext in spite of waiting for %d, maxNumTries = %d". - format(numTries * waitTime, maxNumTries)) - this.yarnAllocator = YarnAllocationHandler.newAllocator( - yarnConf, - amClient, - appAttemptId, - args, - sparkContext.getConf) - } - } - } - } - - private def allocateExecutors() { - try { - logInfo("Requesting" + args.numExecutors + " executors.") - // Wait until all containers have launched - yarnAllocator.addResourceRequests(args.numExecutors) - yarnAllocator.allocateResources() - // Exits the loop if the user thread exits. - - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive - && !isFinished) { - checkNumExecutorsFailed() - allocateMissingExecutor() - yarnAllocator.allocateResources() - Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) - } - } - logInfo("All executors have launched.") - } - - private def allocateMissingExecutor() { - val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - - yarnAllocator.getNumPendingAllocate - if (missingExecutorCount > 0) { - logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingExecutorCount)) - yarnAllocator.addResourceRequests(missingExecutorCount) - } - } - - private def checkNumExecutorsFailed() { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - logInfo("max number of executor failures reached") - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - // make sure to stop the user thread - val sparkContext = ApplicationMaster.sparkContextRef.get() - if (sparkContext != null) { - logInfo("Invoking sc stop from checkNumExecutorsFailed") - sparkContext.stop() - } else { - logError("sparkContext is null when should shutdown") - } - } - } - - private def launchReporterThread(): Thread = { - // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. - val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - - // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) - - // must be <= timeoutInterval / 2. - val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) - - val t = new Thread { - override def run() { - while (userThread.isAlive && !isFinished) { - checkNumExecutorsFailed() - allocateMissingExecutor() - logDebug("Sending progress") - yarnAllocator.allocateResources() - Thread.sleep(interval) - } - } - } - // Setting to daemon status, though this is usually not a good idea. - t.setDaemon(true) - t.start() - logInfo("Started progress reporter thread - heartbeat interval : " + interval) - t - } - - def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") { - synchronized { - if (isFinished) { - return - } - isFinished = true - - logInfo("Unregistering ApplicationMaster with " + status) - if (registered) { - amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) - } - } - } - - /** - * Clean up the staging directory. - */ - private def cleanupStagingDir() { - var stagingDirPath: Path = null - try { - val preserveFiles = sparkConf.get("spark.yarn.preserve.staging.files", "false").toBoolean - if (!preserveFiles) { - stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) - if (stagingDirPath == null) { - logError("Staging directory is null") - return - } - logInfo("Deleting staging directory " + stagingDirPath) - fs.delete(stagingDirPath, true) - } - } catch { - case ioe: IOException => - logError("Failed to cleanup staging dir " + stagingDirPath, ioe) - } - } - - // The shutdown hook that runs when a signal is received AND during normal close of the JVM. - class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable { - - def run() { - logInfo("AppMaster received a signal.") - // We need to clean up staging dir before HDFS is shut down - // make sure we don't delete it until this is the last AM - if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir() - } - } - -} - -object ApplicationMaster extends Logging { - // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be - // optimal as more containers are available. Might need to handle this better. - private val ALLOCATE_HEARTBEAT_INTERVAL = 100 - - private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() - - val sparkContextRef: AtomicReference[SparkContext] = - new AtomicReference[SparkContext](null) - - def register(master: ApplicationMaster) { - applicationMasters.add(master) - } - - /** - * Called from YarnClusterScheduler to notify the AM code that a SparkContext has been - * initialized in the user code. - */ - def sparkContextInitialized(sc: SparkContext): Boolean = { - var modified = false - sparkContextRef.synchronized { - modified = sparkContextRef.compareAndSet(null, sc) - sparkContextRef.notifyAll() - } - - // Add a shutdown hook - as a best effort in case users do not call sc.stop or do - // System.exit. - // Should not really have to do this, but it helps YARN to evict resources earlier. - // Not to mention, prevent the Client from declaring failure even though we exited properly. - // Note that this will unfortunately not properly clean up the staging files because it gets - // called too late, after the filesystem is already shutdown. - if (modified) { - Runtime.getRuntime().addShutdownHook(new Thread with Logging { - // This is not only logs, but also ensures that log system is initialized for this instance - // when we are actually 'run'-ing. - logInfo("Adding shutdown hook for context " + sc) - - override def run() { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - // Best case ... - for (master <- applicationMasters) { - master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) - } - } - }) - } - - // Wait for initialization to complete and at least 'some' nodes to get allocated. - modified - } - - def getApplicationAttemptId(): ApplicationAttemptId = { - val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - val appAttemptId = containerId.getApplicationAttemptId() - appAttemptId - } - - def main(argStrings: Array[String]) { - SignalLogger.register(log) - val args = new ApplicationMasterArguments(argStrings) - SparkHadoopUtil.get.runAsSparkUser { () => - new ApplicationMaster(args).run() - } - } -} diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 15f3c4f180ea3..82e45e3e7ad54 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf} - +import org.apache.spark.deploy.SparkHadoopUtil /** * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's stable API. @@ -40,7 +40,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa val yarnClient = YarnClient.createYarnClient def this(clientArgs: ClientArguments, spConf: SparkConf) = - this(clientArgs, new Configuration(), spConf) + this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) @@ -99,26 +99,8 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def logClusterResourceDetails() { val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics - logInfo("Got Cluster metric info from ResourceManager, number of NodeManagers: " + + logInfo("Got cluster metric info from ResourceManager, number of NodeManagers: " + clusterMetrics.getNumNodeManagers) - - val queueInfo: QueueInfo = yarnClient.getQueueInfo(args.amQueue) - logInfo( """Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s, - queueApplicationCount = %s, queueChildQueueCount = %s""".format( - queueInfo.getQueueName, - queueInfo.getCurrentCapacity, - queueInfo.getMaximumCapacity, - queueInfo.getApplications.size, - queueInfo.getChildQueues.size)) - } - - def calculateAMMemory(newApp: GetNewApplicationResponse) :Int = { - // TODO: Need a replacement for the following code to fix -Xmx? - // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() - // var amMemory = ((args.amMemory / minResMemory) * minResMemory) + - // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - // memoryOverhead ) - args.amMemory } def setupSecurityToken(amContainer: ContainerLaunchContext) = { diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala deleted file mode 100644 index 45925f1fea005..0000000000000 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn - -import java.net.Socket -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationConstants -import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.api.protocolrecords._ -import org.apache.hadoop.yarn.conf.YarnConfiguration -import akka.actor._ -import akka.remote._ -import org.apache.spark.{Logging, SecurityManager, SparkConf} -import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter -import org.apache.spark.scheduler.SplitInfo -import org.apache.hadoop.yarn.client.api.AMRMClient -import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.hadoop.yarn.webapp.util.WebAppUtils - -/** - * An application master that allocates executors on behalf of a driver that is running outside - * the cluster. - * - * This is used only in yarn-client mode. - */ -class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sparkConf: SparkConf) - extends Logging { - - def this(args: ApplicationMasterArguments, sparkConf: SparkConf) = - this(args, new Configuration(), sparkConf) - - def this(args: ApplicationMasterArguments) = this(args, new SparkConf()) - - private var appAttemptId: ApplicationAttemptId = _ - private var reporterThread: Thread = _ - private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - - private var yarnAllocator: YarnAllocationHandler = _ - private var driverClosed: Boolean = false - private var isFinished: Boolean = false - private var registered: Boolean = false - - private var amClient: AMRMClient[ContainerRequest] = _ - - // Default to numExecutors * 2, with minimum of 3 - private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) - - val securityManager = new SecurityManager(sparkConf) - val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityManager)._1 - var actor: ActorRef = _ - - // This actor just working as a monitor to watch on Driver Actor. - class MonitorActor(driverUrl: String) extends Actor { - - var driver: ActorSelection = _ - - override def preStart() { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - - override def receive = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - driverClosed = true - case x: AddWebUIFilter => - logInfo(s"Add WebUI Filter. $x") - driver ! x - } - } - - def run() { - amClient = AMRMClient.createAMRMClient() - amClient.init(yarnConf) - amClient.start() - - appAttemptId = ApplicationMaster.getApplicationAttemptId() - synchronized { - if (!isFinished) { - registerApplicationMaster() - registered = true - } - } - - waitForSparkMaster() - addAmIpFilter() - - // Allocate all containers - allocateExecutors() - - // Launch a progress reporter thread, else app will get killed after expiration - // (def: 10mins) timeout ensure that progress is sent before - // YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. - - val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong - // must be <= timeoutInterval / 2. - val interval = math.min(timeoutInterval / 2, schedulerInterval) - - reporterThread = launchReporterThread(interval) - - - // Wait for the reporter thread to Finish. - reporterThread.join() - - finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) - actorSystem.shutdown() - - logInfo("Exited") - System.exit(0) - } - - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") - logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") - amClient.registerApplicationMaster(Utils.localHostName(), 0, appUIAddress) - } - - // add the yarn amIpFilter that Yarn requires for properly securing the UI - private def addAmIpFilter() { - val proxy = WebAppUtils.getProxyHostAndPort(conf) - val parts = proxy.split(":") - val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) - val uriBase = "http://" + proxy + proxyBase - val amFilter = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase - val amFilterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" - actor ! AddWebUIFilter(amFilterName, amFilter, proxyBase) - } - - private def waitForSparkMaster() { - logInfo("Waiting for Spark driver to be reachable.") - var driverUp = false - val hostport = args.userArgs(0) - val (driverHost, driverPort) = Utils.parseHostPort(hostport) - while(!driverUp) { - try { - val socket = new Socket(driverHost, driverPort) - socket.close() - logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) - driverUp = true - } catch { - case e: Exception => - logError("Failed to connect to driver at %s:%s, retrying ...". - format(driverHost, driverPort)) - Thread.sleep(100) - } - } - sparkConf.set("spark.driver.host", driverHost) - sparkConf.set("spark.driver.port", driverPort.toString) - - val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME) - - actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") - } - - - private def allocateExecutors() { - // TODO: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. - val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = - scala.collection.immutable.Map() - - yarnAllocator = YarnAllocationHandler.newAllocator( - yarnConf, - amClient, - appAttemptId, - args, - preferredNodeLocationData, - sparkConf) - - logInfo("Requesting " + args.numExecutors + " executors.") - // Wait until all containers have launched - yarnAllocator.addResourceRequests(args.numExecutors) - yarnAllocator.allocateResources() - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && - !isFinished) { - checkNumExecutorsFailed() - allocateMissingExecutor() - yarnAllocator.allocateResources() - Thread.sleep(100) - } - - logInfo("All executors have launched.") - } - - private def allocateMissingExecutor() { - val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - - yarnAllocator.getNumPendingAllocate - if (missingExecutorCount > 0) { - logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingExecutorCount)) - yarnAllocator.addResourceRequests(missingExecutorCount) - } - } - - private def checkNumExecutorsFailed() { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } - } - - private def launchReporterThread(_sleepTime: Long): Thread = { - val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime - - val t = new Thread { - override def run() { - while (!driverClosed && !isFinished) { - checkNumExecutorsFailed() - allocateMissingExecutor() - logDebug("Sending progress") - yarnAllocator.allocateResources() - Thread.sleep(sleepTime) - } - } - } - // setting to daemon status, though this is usually not a good idea. - t.setDaemon(true) - t.start() - logInfo("Started progress reporter thread - sleep time : " + sleepTime) - t - } - - def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { - synchronized { - if (isFinished) { - return - } - logInfo("Unregistering ApplicationMaster with " + status) - if (registered) { - val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") - amClient.unregisterApplicationMaster(status, appMessage, trackingUrl) - } - isFinished = true - } - } - -} - -object ExecutorLauncher { - def main(argStrings: Array[String]) { - val args = new ApplicationMasterArguments(argStrings) - SparkHadoopUtil.get.runAsSparkUser { () => - new ExecutorLauncher(args).run() - } - } -} diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 07ba0a4b30bd7..833be12982e71 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SecurityManager, SparkConf, Logging} class ExecutorRunnable( @@ -46,7 +46,8 @@ class ExecutorRunnable( slaveId: String, hostname: String, executorMemory: Int, - executorCores: Int) + executorCores: Int, + securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { var rpc: YarnRPC = YarnRPC.create(conf) @@ -85,6 +86,8 @@ class ExecutorRunnable( logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) + ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // Send the start request to the ContainerManager nmClient.startContainer(container, ctx) } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 29ccec2adcac3..5438f151ac0ad 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -17,374 +17,46 @@ package org.apache.spark.deploy.yarn -import java.lang.{Boolean => JBoolean} -import java.util.{Collections, Set => JSet} -import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.scheduler.SplitInfo import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationMasterProtocol -import org.apache.hadoop.yarn.api.records.ApplicationAttemptId -import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus} -import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest} -import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.hadoop.yarn.util.{RackResolver, Records} - - -object AllocationType extends Enumeration { - type AllocationType = Value - val HOST, RACK, ANY = Value -} - -// TODO: -// Too many params. -// Needs to be mt-safe -// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should -// make it more proactive and decoupled. - -// Note that right now, we assume all node asks as uniform in terms of capabilities and priority -// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for -// more info on how we are requesting for containers. +import org.apache.hadoop.yarn.util.Records /** * Acquires resources for executors from a ResourceManager and launches executors in new containers. */ private[yarn] class YarnAllocationHandler( - val conf: Configuration, - val amClient: AMRMClient[ContainerRequest], - val appAttemptId: ApplicationAttemptId, - val maxExecutors: Int, - val executorMemory: Int, - val executorCores: Int, - val preferredHostToCount: Map[String, Int], - val preferredRackToCount: Map[String, Int], - val sparkConf: SparkConf) - extends Logging { - // These three are locked on allocatedHostToContainersMap. Complementary data structures - // allocatedHostToContainersMap : containers which are running : host, Set - // allocatedContainerToHostMap: container to host mapping. - private val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]]() - - private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() - - // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an - // allocated node) - // As with the two data structures above, tightly coupled with them, and to be locked on - // allocatedHostToContainersMap - private val allocatedRackCount = new HashMap[String, Int]() - - // Containers which have been released. - private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() - // Containers to be released in next request to RM - private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] - - // Additional memory overhead - in mb. - private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnAllocationHandler.MEMORY_OVERHEAD) - - // Number of container requests that have been sent to, but not yet allocated by the - // ApplicationMaster. - private val numPendingAllocate = new AtomicInteger() - private val numExecutorsRunning = new AtomicInteger() - // Used to generate a unique id per executor - private val executorIdCounter = new AtomicInteger() - private val lastResponseId = new AtomicInteger() - private val numExecutorsFailed = new AtomicInteger() - - def getNumPendingAllocate: Int = numPendingAllocate.intValue - - def getNumExecutorsRunning: Int = numExecutorsRunning.intValue - - def getNumExecutorsFailed: Int = numExecutorsFailed.intValue + conf: Configuration, + sparkConf: SparkConf, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + preferredNodes: collection.Map[String, collection.Set[SplitInfo]], + securityMgr: SecurityManager) + extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { - def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + memoryOverhead) + override protected def releaseContainer(container: Container) = { + amClient.releaseAssignedContainer(container.getId()) } - def releaseContainer(container: Container) { - val containerId = container.getId - pendingReleaseContainers.put(containerId, true) - amClient.releaseAssignedContainer(containerId) - } + override protected def allocateContainers(count: Int): YarnAllocateResponse = { + addResourceRequests(count) - def allocateResources() { // We have already set the container request. Poll the ResourceManager for a response. // This doubles as a heartbeat if there are no pending container requests. val progressIndicator = 0.1f - val allocateResponse = amClient.allocate(progressIndicator) - - val allocatedContainers = allocateResponse.getAllocatedContainers() - if (allocatedContainers.size > 0) { - var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) - - if (numPendingAllocateNow < 0) { - numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) - } - - logDebug(""" - Allocated containers: %d - Current executor count: %d - Containers released: %s - Containers to-be-released: %s - Cluster resources: %s - """.format( - allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers, - allocateResponse.getAvailableResources)) - - val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (container <- allocatedContainers) { - if (isResourceConstraintSatisfied(container)) { - // Add the accepted `container` to the host's list of already accepted, - // allocated containers - val host = container.getNodeId.getHost - val containersForHost = hostToContainers.getOrElseUpdate(host, - new ArrayBuffer[Container]()) - containersForHost += container - } else { - // Release container, since it doesn't satisfy resource constraints. - releaseContainer(container) - } - } - - // Find the appropriate containers to use. - // TODO: Cleanup this group-by... - val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (candidateHost <- hostToContainers.keySet) { - val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) - val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) - - val remainingContainersOpt = hostToContainers.get(candidateHost) - assert(remainingContainersOpt.isDefined) - var remainingContainers = remainingContainersOpt.get - - if (requiredHostCount >= remainingContainers.size) { - // Since we have <= required containers, add all remaining containers to - // `dataLocalContainers`. - dataLocalContainers.put(candidateHost, remainingContainers) - // There are no more free containers remaining. - remainingContainers = null - } else if (requiredHostCount > 0) { - // Container list has more containers than we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. - val (dataLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredHostCount) - dataLocalContainers.put(candidateHost, dataLocal) - - // Invariant: remainingContainers == remaining - - // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. - // Add each container in `remaining` to list of containers to release. If we have an - // insufficient number of containers, then the next allocation cycle will reallocate - // (but won't treat it as data local). - // TODO(harvey): Rephrase this comment some more. - for (container <- remaining) releaseContainer(container) - remainingContainers = null - } - - // For rack local containers - if (remainingContainers != null) { - val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) - if (rack != null) { - val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) - val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - - rackLocalContainers.getOrElse(rack, List()).size - - if (requiredRackCount >= remainingContainers.size) { - // Add all remaining containers to to `dataLocalContainers`. - dataLocalContainers.put(rack, remainingContainers) - remainingContainers = null - } else if (requiredRackCount > 0) { - // Container list has more containers that we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. - val (rackLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredRackCount) - val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, - new ArrayBuffer[Container]()) - - existingRackLocal ++= rackLocal - - remainingContainers = remaining - } - } - } - - if (remainingContainers != null) { - // Not all containers have been consumed - add them to the list of off-rack containers. - offRackContainers.put(candidateHost, remainingContainers) - } - } - - // Now that we have split the containers into various groups, go through them in order: - // first host-local, then rack-local, and finally off-rack. - // Note that the list we create below tries to ensure that not all containers end up within - // a host if there is a sufficiently large number of hosts/containers. - val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) - - // Run each of the allocated containers. - for (container <- allocatedContainersToProcess) { - val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() - val executorHostname = container.getNodeId.getHost - val containerId = container.getId - - val executorMemoryOverhead = (executorMemory + memoryOverhead) - assert(container.getResource.getMemory >= executorMemoryOverhead) - - if (numExecutorsRunningNow > maxExecutors) { - logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, executorHostname)) - releaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - val executorId = executorIdCounter.incrementAndGet().toString - val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - - logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) - - // To be safe, remove the container from `pendingReleaseContainers`. - pendingReleaseContainers.remove(containerId) - - val rack = YarnAllocationHandler.lookupRack(conf, executorHostname) - allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, - new HashSet[ContainerId]()) - - containerSet += containerId - allocatedContainerToHostMap.put(containerId, executorHostname) - - if (rack != null) { - allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) - } - } - logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( - driverUrl, executorHostname)) - val executorRunnable = new ExecutorRunnable( - container, - conf, - sparkConf, - driverUrl, - executorId, - executorHostname, - executorMemory, - executorCores) - new Thread(executorRunnable).start() - } - } - logDebug(""" - Finished allocating %s containers (from %s originally). - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - allocatedContainersToProcess, - allocatedContainers, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } - - val completedContainers = allocateResponse.getCompletedContainersStatuses() - if (completedContainers.size > 0) { - logDebug("Completed %d containers".format(completedContainers.size)) - - for (completedContainer <- completedContainers) { - val containerId = completedContainer.getContainerId - - if (pendingReleaseContainers.containsKey(containerId)) { - // YarnAllocationHandler already marked the container for release, so remove it from - // `pendingReleaseContainers`. - pendingReleaseContainers.remove(containerId) - } else { - // Decrement the number of executors running. The next iteration of - // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() - logInfo("Completed container %s (state: %s, exit status: %s)".format( - containerId, - completedContainer.getState, - completedContainer.getExitStatus())) - // Hadoop 2.2.X added a ContainerExitStatus we should switch to use - // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus() != 0) { - logInfo("Container marked as failed: " + containerId) - numExecutorsFailed.incrementAndGet() - } - } - - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val hostOpt = allocatedContainerToHostMap.get(containerId) - assert(hostOpt.isDefined) - val host = hostOpt.get - - val containerSetOpt = allocatedHostToContainersMap.get(host) - assert(containerSetOpt.isDefined) - val containerSet = containerSetOpt.get - - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) - - // TODO: Move this part outside the synchronized block? - val rack = YarnAllocationHandler.lookupRack(conf, host) - if (rack != null) { - val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 - if (rackCount > 0) { - allocatedRackCount.put(rack, rackCount) - } else { - allocatedRackCount.remove(rack) - } - } - } - } - } - logDebug(""" - Finished processing %d completed containers. - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - completedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } + new StableAllocateResponse(amClient.allocate(progressIndicator)) } - def createRackResourceRequests( + private def createRackResourceRequests( hostContainers: ArrayBuffer[ContainerRequest] ): ArrayBuffer[ContainerRequest] = { // Generate modified racks and new set of hosts under it before issuing requests. @@ -392,9 +64,9 @@ private[yarn] class YarnAllocationHandler( for (container <- hostContainers) { val candidateHost = container.getNodes.last - assert(YarnAllocationHandler.ANY_HOST != candidateHost) + assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost) - val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) if (rack != null) { var count = rackToCounts.getOrElse(rack, 0) count += 1 @@ -408,40 +80,26 @@ private[yarn] class YarnAllocationHandler( AllocationType.RACK, rack, count, - YarnAllocationHandler.PRIORITY) + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) } requestedContainers } - def allocatedContainersOnHost(host: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedHostToContainersMap.getOrElse(host, Set()).size - } - retval - } - - def allocatedContainersOnRack(rack: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedRackCount.getOrElse(rack, 0) - } - retval - } - - def addResourceRequests(numExecutors: Int) { + private def addResourceRequests(numExecutors: Int) { val containerRequests: List[ContainerRequest] = - if (numExecutors <= 0 || preferredHostToCount.isEmpty) { - logDebug("numExecutors: " + numExecutors + ", host preferences: " + - preferredHostToCount.isEmpty) + if (numExecutors <= 0) { + logDebug("numExecutors: " + numExecutors) + List() + } else if (preferredHostToCount.isEmpty) { + logDebug("host preferences is empty") createResourceRequests( AllocationType.ANY, resource = null, numExecutors, - YarnAllocationHandler.PRIORITY).toList + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY).toList } else { - // Request for all hosts in preferred nodes and for numExecutors - + // Request for all hosts in preferred nodes and for numExecutors - // candidates.size, request by default allocation policy. val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size) for ((candidateHost, candidateCount) <- preferredHostToCount) { @@ -452,7 +110,7 @@ private[yarn] class YarnAllocationHandler( AllocationType.HOST, candidateHost, requiredCount, - YarnAllocationHandler.PRIORITY) + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) } } val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests( @@ -462,7 +120,7 @@ private[yarn] class YarnAllocationHandler( AllocationType.ANY, resource = null, numExecutors, - YarnAllocationHandler.PRIORITY) + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) val containerRequestBuffer = new ArrayBuffer[ContainerRequest]( hostContainerRequests.size + rackContainerRequests.size() + anyContainerRequests.size) @@ -477,15 +135,6 @@ private[yarn] class YarnAllocationHandler( amClient.addContainerRequest(request) } - if (numExecutors > 0) { - numPendingAllocate.addAndGet(numExecutors) - logInfo("Will Allocate %d executor containers, each with %d memory".format( - numExecutors, - (executorMemory + memoryOverhead))) - } else { - logDebug("Empty allocation request ...") - } - for (request <- containerRequests) { val nodes = request.getNodes var hostStr = if (nodes == null || nodes.isEmpty) { @@ -511,7 +160,7 @@ private[yarn] class YarnAllocationHandler( // There must be a third request, which is ANY. That will be specially handled. requestType match { case AllocationType.HOST => { - assert(YarnAllocationHandler.ANY_HOST != resource) + assert(YarnSparkHadoopUtil.ANY_HOST != resource) val hostname = resource val nodeLocal = constructContainerRequests( Array(hostname), @@ -520,7 +169,7 @@ private[yarn] class YarnAllocationHandler( priority) // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler. - YarnAllocationHandler.populateRackInfo(conf, hostname) + YarnSparkHadoopUtil.populateRackInfo(conf, hostname) nodeLocal } case AllocationType.RACK => { @@ -553,152 +202,11 @@ private[yarn] class YarnAllocationHandler( } requests } -} - -object YarnAllocationHandler { - - val ANY_HOST = "*" - // All requests are issued with same priority : we do not (yet) have any distinction between - // request types (like map/reduce in hadoop for example) - val PRIORITY = 1 - - // Additional memory overhead - in mb. - val MEMORY_OVERHEAD = 384 - - // Host to rack map - saved from allocation requests. We are expecting this not to change. - // Note that it is possible for this to change : and ResurceManager will indicate that to us via - // update response to allocate. But we are punting on handling that for now. - private val hostToRack = new ConcurrentHashMap[String, String]() - private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() - - - def newAllocator( - conf: Configuration, - amClient: AMRMClient[ContainerRequest], - appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments, - sparkConf: SparkConf - ): YarnAllocationHandler = { - new YarnAllocationHandler( - conf, - amClient, - appAttemptId, - args.numExecutors, - args.executorMemory, - args.executorCores, - Map[String, Int](), - Map[String, Int](), - sparkConf) - } - - def newAllocator( - conf: Configuration, - amClient: AMRMClient[ContainerRequest], - appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments, - map: collection.Map[String, - collection.Set[SplitInfo]], - sparkConf: SparkConf - ): YarnAllocationHandler = { - val (hostToSplitCount, rackToSplitCount) = generateNodeToWeight(conf, map) - new YarnAllocationHandler( - conf, - amClient, - appAttemptId, - args.numExecutors, - args.executorMemory, - args.executorCores, - hostToSplitCount, - rackToSplitCount, - sparkConf) - } - - def newAllocator( - conf: Configuration, - amClient: AMRMClient[ContainerRequest], - appAttemptId: ApplicationAttemptId, - maxExecutors: Int, - executorMemory: Int, - executorCores: Int, - map: collection.Map[String, collection.Set[SplitInfo]], - sparkConf: SparkConf - ): YarnAllocationHandler = { - val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) - new YarnAllocationHandler( - conf, - amClient, - appAttemptId, - maxExecutors, - executorMemory, - executorCores, - hostToCount, - rackToCount, - sparkConf) - } - - // A simple method to copy the split info map. - private def generateNodeToWeight( - conf: Configuration, - input: collection.Map[String, collection.Set[SplitInfo]] - ): (Map[String, Int], Map[String, Int]) = { - - if (input == null) { - return (Map[String, Int](), Map[String, Int]()) - } - - val hostToCount = new HashMap[String, Int] - val rackToCount = new HashMap[String, Int] - - for ((host, splits) <- input) { - val hostCount = hostToCount.getOrElse(host, 0) - hostToCount.put(host, hostCount + splits.size) - - val rack = lookupRack(conf, host) - if (rack != null){ - val rackCount = rackToCount.getOrElse(host, 0) - rackToCount.put(host, rackCount + splits.size) - } - } - - (hostToCount.toMap, rackToCount.toMap) - } - def lookupRack(conf: Configuration, host: String): String = { - if (!hostToRack.contains(host)) { - populateRackInfo(conf, host) - } - hostToRack.get(host) + private class StableAllocateResponse(response: AllocateResponse) extends YarnAllocateResponse { + override def getAllocatedContainers() = response.getAllocatedContainers() + override def getAvailableResources() = response.getAvailableResources() + override def getCompletedContainersStatuses() = response.getCompletedContainersStatuses() } - def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { - Option(rackToHostSet.get(rack)).map { set => - val convertedSet: collection.mutable.Set[String] = set - // TODO: Better way to get a Set[String] from JSet. - convertedSet.toSet - } - } - - def populateRackInfo(conf: Configuration, hostname: String) { - Utils.checkHost(hostname) - - if (!hostToRack.containsKey(hostname)) { - // If there are repeated failures to resolve, all to an ignore list. - val rackInfo = RackResolver.resolve(conf, hostname) - if (rackInfo != null && rackInfo.getNetworkLocation != null) { - val rack = rackInfo.getNetworkLocation - hostToRack.put(hostname, rack) - if (! rackToHostSet.containsKey(rack)) { - rackToHostSet.putIfAbsent(rack, - Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) - } - rackToHostSet.get(rack).add(hostname) - - // TODO(harvey): Figure out what this comment means... - // Since RackResolver caches, we are disabling this for now ... - } /* else { - // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... - hostToRack.put(hostname, null) - } */ - } - } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala new file mode 100644 index 0000000000000..54bc6b14c44ce --- /dev/null +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.{Map, Set} + +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.util.ConverterUtils +import org.apache.hadoop.yarn.webapp.util.WebAppUtils + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.scheduler.SplitInfo +import org.apache.spark.util.Utils + + +/** + * YarnRMClient implementation for the Yarn stable API. + */ +private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMClient with Logging { + + private var amClient: AMRMClient[ContainerRequest] = _ + private var uiHistoryAddress: String = _ + + override def register( + conf: YarnConfiguration, + sparkConf: SparkConf, + preferredNodeLocations: Map[String, Set[SplitInfo]], + uiAddress: String, + uiHistoryAddress: String, + securityMgr: SecurityManager) = { + amClient = AMRMClient.createAMRMClient() + amClient.init(conf) + amClient.start() + this.uiHistoryAddress = uiHistoryAddress + + logInfo("Registering the ApplicationMaster") + amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + new YarnAllocationHandler(conf, sparkConf, amClient, getAttemptId(), args, + preferredNodeLocations, securityMgr) + } + + override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = + amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) + + override def getAttemptId() = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + appAttemptId + } + + override def getProxyHostAndPort(conf: YarnConfiguration) = WebAppUtils.getProxyHostAndPort(conf) + + override def getMaxRegAttempts(conf: YarnConfiguration) = + conf.getInt(YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) + +}